tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_python.py
CHANGED
@@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct
|
|
7
7
|
from tinygrad.dtype import DType, dtypes, ImageDType
|
8
8
|
from tinygrad.helpers import all_same, getenv, flatten
|
9
9
|
from tinygrad.device import Compiled, Compiler, Allocator
|
10
|
-
from tinygrad.codegen.uops import
|
10
|
+
from tinygrad.codegen.uops import UOps, UOp
|
11
11
|
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
|
12
12
|
from tinygrad.renderer import Renderer
|
13
13
|
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
|
@@ -17,7 +17,7 @@ def _load(m, i):
|
|
17
17
|
return m[i]
|
18
18
|
|
19
19
|
def load(inp, j=0):
|
20
|
-
if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate
|
20
|
+
if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,default,gate in zip(*inp)]
|
21
21
|
return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
|
22
22
|
|
23
23
|
def _store(m, i, v):
|
@@ -83,14 +83,11 @@ class PythonProgram:
|
|
83
83
|
elif uop is UOps.DEFINE_VAR:
|
84
84
|
ul[i] = [pvals.pop(0)] * warp_size
|
85
85
|
elif uop is UOps.SPECIAL:
|
86
|
-
if arg[
|
87
|
-
|
88
|
-
|
89
|
-
ul[i] = [x[2-arg[0]] for x in warp]
|
90
|
-
elif uop is UOps.CONST:
|
91
|
-
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
|
86
|
+
if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
|
87
|
+
elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
|
88
|
+
elif uop is UOps.CONST: ul[i] = [arg] * warp_size
|
92
89
|
elif uop is UOps.DEFINE_ACC:
|
93
|
-
ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
|
90
|
+
ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
|
94
91
|
elif uop is UOps.RANGE:
|
95
92
|
if i not in ul: ul[i] = [inp[0][0]] * warp_size
|
96
93
|
else:
|
@@ -100,20 +97,19 @@ class PythonProgram:
|
|
100
97
|
del ul[i]
|
101
98
|
i = loop_ends[i] + 1
|
102
99
|
continue
|
103
|
-
elif uop
|
104
|
-
|
100
|
+
elif uop is UOps.VECTORIZE: ul[i] = inp
|
101
|
+
elif uop in {UOps.CAST, UOps.BITCAST}:
|
102
|
+
assert dtp[0].fmt and dtype.fmt
|
103
|
+
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
|
104
|
+
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
|
105
105
|
else:
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
|
114
|
-
elif dtypes.is_float(dtype):
|
115
|
-
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
|
116
|
-
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
|
106
|
+
casted = [dtypes.as_const(x, dtype) for x in inp[0]]
|
107
|
+
if dtypes.is_int(dtype):
|
108
|
+
overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
|
109
|
+
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
|
110
|
+
elif dtypes.is_float(dtype):
|
111
|
+
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
|
112
|
+
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
|
117
113
|
elif uop is UOps.LOAD:
|
118
114
|
if isinstance(dtp[0], ImageDType):
|
119
115
|
assert dtype.count == 4
|
@@ -136,9 +132,9 @@ class PythonProgram:
|
|
136
132
|
elif uop is UOps.WMMA:
|
137
133
|
# here are the models for the WMMA instruction on the different hardware
|
138
134
|
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
|
139
|
-
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
|
140
|
-
assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
|
141
|
-
assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
|
135
|
+
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
|
136
|
+
assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread, it has {len(inp[1])}"
|
137
|
+
assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread, it has {len(inp[2])}"
|
142
138
|
assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
|
143
139
|
assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
|
144
140
|
assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
|
@@ -152,13 +148,13 @@ class PythonProgram:
|
|
152
148
|
return out
|
153
149
|
|
154
150
|
# TODO: refactor these to a shared TensorCoreLayout in kernel.py
|
155
|
-
if arg[
|
151
|
+
if arg[4] == "METAL":
|
156
152
|
# A (2 elements on 32 threads): row major
|
157
153
|
def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
|
158
154
|
# (i, j), C, D (2 elements on 32 threads): row major same as A/B
|
159
155
|
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
|
160
156
|
ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
|
161
|
-
elif arg[
|
157
|
+
elif arg[4] == "AMD":
|
162
158
|
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
|
163
159
|
def a_elem(x, i, j, goff):
|
164
160
|
assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
|
@@ -167,7 +163,7 @@ class PythonProgram:
|
|
167
163
|
def b_elem(x, i, j, goff): return a_elem(x, j, i, goff) # pylint: disable=arguments-out-of-order
|
168
164
|
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
|
169
165
|
ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
170
|
-
elif arg[
|
166
|
+
elif arg[4] == "CUDA":
|
171
167
|
# A (8 elements on 32 threads)
|
172
168
|
def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
|
173
169
|
# B (4 elements on 32 threads)
|
@@ -191,8 +187,8 @@ class PythonRenderer(Renderer):
|
|
191
187
|
if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
|
192
188
|
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
|
193
189
|
|
194
|
-
def render(self, name:str, uops:
|
195
|
-
lops = [(u.op, u.dtype, [uops.
|
190
|
+
def render(self, name:str, uops:List[UOp]) -> str:
|
191
|
+
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
|
196
192
|
return base64.b64encode(pickle.dumps(lops)).decode()
|
197
193
|
|
198
194
|
class PythonCompiler(Compiler):
|
@@ -0,0 +1,78 @@
|
|
1
|
+
import subprocess, hashlib, tempfile, ctypes, ctypes.util, re, pathlib
|
2
|
+
from typing import Callable
|
3
|
+
from tinygrad.helpers import to_char_p_p, colored, init_c_var, getenv
|
4
|
+
import tinygrad.runtime.autogen.nvrtc as nvrtc
|
5
|
+
from tinygrad.device import Compiler, CompileError
|
6
|
+
|
7
|
+
PTX = getenv("PTX") # this shouldn't be here, in fact, it shouldn't exist
|
8
|
+
|
9
|
+
def _get_bytes(arg, get_str, get_sz, check) -> bytes:
|
10
|
+
sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))
|
11
|
+
return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
|
12
|
+
|
13
|
+
def nvrtc_check(status, ctx=None):
|
14
|
+
if status != 0:
|
15
|
+
err_log = _get_bytes(ctx, nvrtc.nvrtcGetProgramLog, nvrtc.nvrtcGetProgramLogSize, lambda _: None).decode() if ctx else ""
|
16
|
+
raise CompileError(f"Nvrtc Error {status}, {ctypes.string_at(nvrtc.nvrtcGetErrorString(status)).decode()}\n{err_log}")
|
17
|
+
|
18
|
+
def jitlink_check(status, ctx=None):
|
19
|
+
if status != 0:
|
20
|
+
err_log = _get_bytes(ctx, nvrtc.nvJitLinkGetErrorLog, nvrtc.nvJitLinkGetErrorLogSize, lambda _: None).decode() if ctx else ""
|
21
|
+
raise CompileError(f"NvJitLink Error {status}, {nvrtc.nvJitLinkResult__enumvalues.get(status, 'Unknown')}\n{err_log}")
|
22
|
+
|
23
|
+
def pretty_ptx(s):
|
24
|
+
# all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
|
25
|
+
s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
|
26
|
+
s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
|
27
|
+
s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
|
28
|
+
s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
|
29
|
+
s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
|
30
|
+
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
|
31
|
+
return s
|
32
|
+
|
33
|
+
def cuda_disassemble(lib, arch):
|
34
|
+
try:
|
35
|
+
fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
|
36
|
+
with open(fn + ".ptx", "wb") as f: f.write(lib)
|
37
|
+
subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
|
38
|
+
print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
|
39
|
+
except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains ptxas/nvdisasm binary of compatible version.")
|
40
|
+
|
41
|
+
def nv_disassemble(lib):
|
42
|
+
try:
|
43
|
+
fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
|
44
|
+
with open(fn + ".cubin", "wb") as f: f.write(lib)
|
45
|
+
print(subprocess.check_output(["nvdisasm", fn+".cubin"]).decode('utf-8'))
|
46
|
+
except Exception as e: print("Failed to disasm cubin:", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")
|
47
|
+
|
48
|
+
class CUDACompiler(Compiler):
|
49
|
+
def __init__(self, arch:str, cache_key:str="cuda"):
|
50
|
+
self.arch, self.compile_options = arch, [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
|
51
|
+
nvrtc_check(nvrtc.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
|
52
|
+
if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
|
53
|
+
super().__init__(f"compile_{cache_key}_{self.arch}")
|
54
|
+
def _compile_program(self, src:str, nvrtc_get_content:Callable, nvrtc_get_size:Callable) -> bytes:
|
55
|
+
nvrtc_check(nvrtc.nvrtcCreateProgram(ctypes.byref(prog := nvrtc.nvrtcProgram()), src.encode(), "<null>".encode(), 0, None, None))
|
56
|
+
nvrtc_check(nvrtc.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options])), prog)
|
57
|
+
data = _get_bytes(prog, nvrtc_get_content, nvrtc_get_size, nvrtc_check)
|
58
|
+
nvrtc_check(nvrtc.nvrtcDestroyProgram(ctypes.byref(prog)))
|
59
|
+
return data
|
60
|
+
def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize)
|
61
|
+
|
62
|
+
class NVCompiler(CUDACompiler):
|
63
|
+
def __init__(self, arch:str): super().__init__(arch, cache_key="nv")
|
64
|
+
def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize)
|
65
|
+
|
66
|
+
class PTXCompiler(CUDACompiler):
|
67
|
+
def __init__(self, arch:str, cache_key="ptx"): super().__init__(arch, cache_key=cache_key)
|
68
|
+
def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5").encode()
|
69
|
+
|
70
|
+
class NVPTXCompiler(PTXCompiler):
|
71
|
+
def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx")
|
72
|
+
def compile(self, src:str) -> bytes:
|
73
|
+
jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle)
|
74
|
+
jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc:=super().compile(src), len(ptxsrc), "<null>".encode()), handle)
|
75
|
+
jitlink_check(nvrtc.nvJitLinkComplete(handle), handle)
|
76
|
+
data = _get_bytes(handle, nvrtc.nvJitLinkGetLinkedCubin, nvrtc.nvJitLinkGetLinkedCubinSize, jitlink_check)
|
77
|
+
jitlink_check(nvrtc.nvJitLinkDestroy(handle))
|
78
|
+
return data
|
@@ -1,5 +1,6 @@
|
|
1
|
-
import ctypes
|
1
|
+
import ctypes, subprocess
|
2
2
|
import tinygrad.runtime.autogen.comgr as comgr
|
3
|
+
from tinygrad.device import Compiler, CompileError
|
3
4
|
|
4
5
|
def check(status):
|
5
6
|
if status != 0:
|
@@ -54,3 +55,16 @@ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
|
|
54
55
|
for x in [data_set_src, data_set_bc, data_set_reloc, data_set_exec]: check(comgr.amd_comgr_destroy_data_set(x))
|
55
56
|
check(comgr.amd_comgr_destroy_action_info(action_info))
|
56
57
|
return ret
|
58
|
+
|
59
|
+
# this should probably be a method on the Compiler
|
60
|
+
def disasm(lib):
|
61
|
+
asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
|
62
|
+
return '\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x])
|
63
|
+
|
64
|
+
class AMDCompiler(Compiler):
|
65
|
+
def __init__(self, arch:str):
|
66
|
+
self.arch = arch
|
67
|
+
super().__init__(f"compile_hip_{self.arch}")
|
68
|
+
def compile(self, src:str) -> bytes:
|
69
|
+
try: return compile_hip(src, self.arch)
|
70
|
+
except RuntimeError as e: raise CompileError(e) from e
|
@@ -0,0 +1,38 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import Tuple, List, Any
|
3
|
+
from dataclasses import dataclass
|
4
|
+
import tinygrad.runtime.autogen.libc as libc
|
5
|
+
|
6
|
+
@dataclass(frozen=True)
|
7
|
+
class ElfSection: name:str; header:libc.Elf64_Shdr; content:bytes # noqa: E702
|
8
|
+
|
9
|
+
def elf_loader(blob:bytes, force_section_align:int=1) -> Tuple[memoryview, List[ElfSection], Any]:
|
10
|
+
def _strtab(blob: bytes, idx: int) -> str: return blob[idx:blob.find(b'\x00', idx)].decode('utf-8')
|
11
|
+
|
12
|
+
header = libc.Elf64_Ehdr.from_buffer_copy(blob)
|
13
|
+
section_headers = (libc.Elf64_Shdr * header.e_shnum).from_buffer_copy(blob[header.e_shoff:])
|
14
|
+
sh_strtab = blob[(shstrst:=section_headers[header.e_shstrndx].sh_offset):shstrst+section_headers[header.e_shstrndx].sh_size]
|
15
|
+
sections = [ElfSection(_strtab(sh_strtab, sh.sh_name), sh, blob[sh.sh_offset:sh.sh_offset+sh.sh_size]) for sh in section_headers]
|
16
|
+
|
17
|
+
def _to_carray(sh, ctype): return (ctype * (sh.header.sh_size // sh.header.sh_entsize)).from_buffer_copy(sh.content)
|
18
|
+
rel = [(sh, sh.name[4:], _to_carray(sh, libc.Elf64_Rel)) for sh in sections if sh.header.sh_type == libc.SHT_REL]
|
19
|
+
rela = [(sh, sh.name[5:], _to_carray(sh, libc.Elf64_Rela)) for sh in sections if sh.header.sh_type == libc.SHT_RELA]
|
20
|
+
symtab = [_to_carray(sh, libc.Elf64_Sym) for sh in sections if sh.header.sh_type == libc.SHT_SYMTAB][0]
|
21
|
+
progbits = [sh for sh in sections if sh.header.sh_type == libc.SHT_PROGBITS]
|
22
|
+
|
23
|
+
# Prealloc image for all fixed addresses.
|
24
|
+
image = bytearray(max([sh.header.sh_addr + sh.header.sh_size for sh in progbits if sh.header.sh_addr != 0] + [0]))
|
25
|
+
for sh in progbits:
|
26
|
+
if sh.header.sh_addr != 0: image[sh.header.sh_addr:sh.header.sh_addr+sh.header.sh_size] = sh.content
|
27
|
+
else:
|
28
|
+
image += b'\0' * (((align:=max(sh.header.sh_addralign, force_section_align)) - len(image) % align) % align) + sh.content
|
29
|
+
sh.header.sh_addr = len(image) - len(sh.content)
|
30
|
+
|
31
|
+
# Relocations
|
32
|
+
relocs = []
|
33
|
+
for sh, trgt_sh_name, c_rels in rel + rela:
|
34
|
+
target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
|
35
|
+
rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
|
36
|
+
relocs += [(target_image_off + roff, sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]
|
37
|
+
|
38
|
+
return memoryview(image), sections, relocs
|
tinygrad/shape/shapetracker.py
CHANGED
@@ -3,18 +3,9 @@ from __future__ import annotations
|
|
3
3
|
from dataclasses import dataclass
|
4
4
|
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
|
5
5
|
from tinygrad.helpers import merge_dicts, getenv
|
6
|
-
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode,
|
6
|
+
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
|
7
7
|
from tinygrad.shape.view import View, strides_for_shape
|
8
8
|
|
9
|
-
def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
|
10
|
-
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
|
11
|
-
iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
|
12
|
-
vexpr: List[Node] = [valid] if valid is not None else []
|
13
|
-
for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
|
14
|
-
if sh != 1 and st != 0: iexpr.append(idx*st)
|
15
|
-
if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
|
16
|
-
return Node.sum(iexpr), Node.ands(vexpr)
|
17
|
-
|
18
9
|
@dataclass(frozen=True)
|
19
10
|
class ShapeTracker:
|
20
11
|
views: Tuple[View, ...]
|
@@ -32,7 +23,7 @@ class ShapeTracker:
|
|
32
23
|
return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
|
33
24
|
|
34
25
|
@staticmethod
|
35
|
-
def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
|
26
|
+
def from_shape(shape:Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
|
36
27
|
|
37
28
|
@property
|
38
29
|
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
|
@@ -56,7 +47,7 @@ class ShapeTracker:
|
|
56
47
|
assert isinstance(ret, int), f"ret must be integer, {ret=} isn't"
|
57
48
|
return ret+1
|
58
49
|
|
59
|
-
def vars(self) -> Set[Variable]: return set.union(*[v.vars() for v in self.views]
|
50
|
+
def vars(self) -> Set[Variable]: return set().union(*[v.vars() for v in self.views])
|
60
51
|
|
61
52
|
@property
|
62
53
|
def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
|
@@ -86,7 +77,7 @@ class ShapeTracker:
|
|
86
77
|
|
87
78
|
def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
|
88
79
|
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
|
89
|
-
idx, valid =
|
80
|
+
idx, valid = self.views[-1].expr(idxs)
|
90
81
|
for view in reversed(self.views[0:-1]):
|
91
82
|
if valid.max == 0: return NumNode(-1), valid
|
92
83
|
view = view.minify()
|
@@ -94,7 +85,7 @@ class ShapeTracker:
|
|
94
85
|
for d in reversed(view.shape):
|
95
86
|
idxs.append((idx//acc)%d)
|
96
87
|
acc *= d
|
97
|
-
idx, valid =
|
88
|
+
idx, valid = view.expr(idxs[::-1], valid)
|
98
89
|
assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
|
99
90
|
assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
|
100
91
|
return idx, valid
|
tinygrad/shape/symbolic.py
CHANGED
@@ -43,6 +43,7 @@ class Node:
|
|
43
43
|
if b == 1: return self
|
44
44
|
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
|
45
45
|
def __rmul__(self, b:int): return self*b
|
46
|
+
def __lshift__(self, b:int): return self*2**b
|
46
47
|
|
47
48
|
# *** complex ops ***
|
48
49
|
|
@@ -74,7 +75,6 @@ class Node:
|
|
74
75
|
assert b > 0
|
75
76
|
if b == 1: return NumNode(0)
|
76
77
|
if isinstance(self.max, int) and isinstance(self.min, int):
|
77
|
-
if self.min >= 0 and self.max < b: return self
|
78
78
|
if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
|
79
79
|
if self.min < 0: return (self - ((self.min//b)*b)) % b
|
80
80
|
return create_node(ModNode(self, b))
|
@@ -231,7 +231,7 @@ class RedNode(Node):
|
|
231
231
|
def __init__(self, nodes:List[Node]):
|
232
232
|
self.nodes = nodes
|
233
233
|
self.min, self.max = self.get_bounds()
|
234
|
-
def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes]
|
234
|
+
def vars(self) -> Set[Variable]: return set().union(*[x.vars() for x in self.nodes])
|
235
235
|
def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
|
236
236
|
|
237
237
|
class SumNode(RedNode):
|
@@ -291,11 +291,7 @@ class SumNode(RedNode):
|
|
291
291
|
class AndNode(RedNode):
|
292
292
|
def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
|
293
293
|
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
294
|
-
|
295
|
-
for node in self.nodes:
|
296
|
-
if not (sub:=node.substitute(var_vals)): return NumNode(0)
|
297
|
-
subed.append(sub)
|
298
|
-
return Node.ands(subed)
|
294
|
+
return Node.ands([node.substitute(var_vals) for node in self.nodes])
|
299
295
|
|
300
296
|
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
|
301
297
|
def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
|
@@ -324,4 +320,4 @@ render_python: Dict[Type, Callable[..., str]] = {
|
|
324
320
|
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
|
325
321
|
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
326
322
|
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
327
|
-
}
|
323
|
+
}
|
tinygrad/shape/view.py
CHANGED
@@ -3,7 +3,7 @@ import functools, operator, itertools, math
|
|
3
3
|
from dataclasses import dataclass
|
4
4
|
from typing import Tuple, List, Optional, Dict, Set, cast
|
5
5
|
from tinygrad.helpers import prod, all_int, argsort
|
6
|
-
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
|
6
|
+
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer, create_lt_node, create_ge_node
|
7
7
|
|
8
8
|
@functools.lru_cache(maxsize=None)
|
9
9
|
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
|
@@ -35,14 +35,17 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu
|
|
35
35
|
return tuple(ret)
|
36
36
|
|
37
37
|
@functools.lru_cache(maxsize=None)
|
38
|
-
def _reshape_mask(
|
39
|
-
|
40
|
-
|
41
|
-
|
38
|
+
def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) \
|
39
|
+
-> Optional[Tuple[Tuple[sint, sint], ...]]:
|
40
|
+
"""Returns the new mask if reshape is possible, and None if not possible."""
|
41
|
+
if _mask is None: return tuple((0, s) for s in new_shape)
|
42
|
+
if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in _mask): return None
|
43
|
+
if any(m[1] - m[0] < 1 for m in _mask): return ((0, 0),) * len(new_shape) # zero mask
|
42
44
|
|
43
|
-
|
45
|
+
new_mask: List[Tuple[int, int]] = []
|
46
|
+
# _mask is all int here
|
47
|
+
r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
|
44
48
|
curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
|
45
|
-
if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
|
46
49
|
|
47
50
|
while len(new_mask) < len(new_shape):
|
48
51
|
(l, r), next_stride = mask, new_dim * curr_stride
|
@@ -51,24 +54,23 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
|
|
51
54
|
if old_dim == next_stride: # simply copy the mask and get next batch for merging
|
52
55
|
new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
|
53
56
|
curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
|
54
|
-
if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
|
55
57
|
|
56
58
|
else: # mask can only be splitted if reshape doesn't cut across the mask.
|
57
59
|
if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
|
58
|
-
or old_dim % next_stride != 0): return
|
60
|
+
or old_dim % next_stride != 0): return None
|
59
61
|
new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
|
60
62
|
curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
|
61
63
|
|
62
64
|
else:
|
63
65
|
next_mask = next(r_masks, (0, 1))
|
64
66
|
# combine if the mask can unfold continuously
|
65
|
-
if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return
|
67
|
+
if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return None
|
66
68
|
mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
|
67
69
|
|
68
70
|
for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
|
69
|
-
if mask != (0, 1): return ((0, 0),) * len(new_shape)
|
71
|
+
if mask != (0, 1): return ((0, 0),) * len(new_shape) # invalid mask
|
70
72
|
|
71
|
-
return tuple(reversed(new_mask))
|
73
|
+
return tuple(reversed(new_mask))
|
72
74
|
|
73
75
|
def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
|
74
76
|
strides = strides_for_shape(shape)
|
@@ -97,6 +99,7 @@ class View:
|
|
97
99
|
@staticmethod
|
98
100
|
@functools.lru_cache(maxsize=None)
|
99
101
|
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
|
102
|
+
if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
|
100
103
|
strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
|
101
104
|
# canonicalize 0 in shape
|
102
105
|
if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
|
@@ -120,13 +123,13 @@ class View:
|
|
120
123
|
|
121
124
|
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
122
125
|
def unbind(self) -> Tuple[View, Dict[Variable, int]]:
|
123
|
-
var_unboundvar_val = [(v, v.unbind()) for v in self.vars()
|
126
|
+
var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
|
124
127
|
unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
128
|
+
def substitute(x): return x if isinstance(x, int) else x.substitute(unbound_vars)
|
129
|
+
new_shape = tuple(map(substitute, self.shape))
|
130
|
+
new_strides = tuple(map(substitute, self.strides))
|
131
|
+
new_offset = substitute(self.offset)
|
132
|
+
new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
|
130
133
|
return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
|
131
134
|
|
132
135
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
@@ -301,11 +304,20 @@ class View:
|
|
301
304
|
if acc != merged_dim: break
|
302
305
|
else:
|
303
306
|
strides += [0,] * (len(new_shape) - len(strides))
|
304
|
-
new_mask
|
305
|
-
if not
|
306
|
-
new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask)
|
307
|
+
new_mask = _reshape_mask(self.mask, self.shape, new_shape)
|
308
|
+
if new_mask is not None:
|
309
|
+
new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
|
307
310
|
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
|
308
|
-
(sum(m[0] * s for m,s in zip(new_mask, new_strides))
|
311
|
+
(sum(m[0] * s for m,s in zip(new_mask, new_strides)))
|
309
312
|
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
|
310
313
|
|
311
314
|
return None
|
315
|
+
|
316
|
+
def expr(self, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
|
317
|
+
assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
|
318
|
+
iexpr: List[Node] = [NumNode(self.offset) if isinstance(self.offset, int) else self.offset]
|
319
|
+
vexpr: List[Node] = [valid] if valid is not None else []
|
320
|
+
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
|
321
|
+
if sh != 1 and st != 0: iexpr.append(idx*st)
|
322
|
+
if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
|
323
|
+
return Node.sum(iexpr), Node.ands(vexpr)
|