tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_python.py CHANGED
@@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct
 from tinygrad.dtype import DType, dtypes, ImageDType
 from tinygrad.helpers import all_same, getenv, flatten
 from tinygrad.device import Compiled, Compiler, Allocator
-from tinygrad.codegen.uops import UOpGraph, UOps
+from tinygrad.codegen.uops import UOps, UOp
 from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
@@ -17,7 +17,7 @@ def _load(m, i):
   return m[i]
 
 def load(inp, j=0):
-  if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
+  if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,default,gate in zip(*inp)]
   return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
 
 def _store(m, i, v):
@@ -83,14 +83,11 @@ class PythonProgram:
       elif uop is UOps.DEFINE_VAR:
         ul[i] = [pvals.pop(0)] * warp_size
       elif uop is UOps.SPECIAL:
-        if arg[1][0] == 'g':
-          ul[i] = [idxs[2-arg[0]]] * warp_size
-        elif arg[1][0] == 'l':
-          ul[i] = [x[2-arg[0]] for x in warp]
-      elif uop is UOps.CONST:
-        ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
+        if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
+        elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
+      elif uop is UOps.CONST: ul[i] = [arg] * warp_size
       elif uop is UOps.DEFINE_ACC:
-        ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
+        ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
       elif uop is UOps.RANGE:
         if i not in ul: ul[i] = [inp[0][0]] * warp_size
         else:
@@ -100,20 +97,19 @@ class PythonProgram:
           del ul[i]
          i = loop_ends[i] + 1
          continue
-      elif uop in (UOps.CAST, UOps.BITCAST):
-        if dtype.count > 1: ul[i] = inp
+      elif uop is UOps.VECTORIZE: ul[i] = inp
+      elif uop in {UOps.CAST, UOps.BITCAST}:
+        assert dtp[0].fmt and dtype.fmt
+        pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
+        if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
         else:
-          assert dtp[0].fmt and dtype.fmt
-          pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
-          if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
-          else:
-            casted = [dtypes.as_const(x, dtype) for x in inp[0]]
-            if dtypes.is_int(dtype):
-              overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
-              casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
-            elif dtypes.is_float(dtype):
-              casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
-            ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
+          casted = [dtypes.as_const(x, dtype) for x in inp[0]]
+          if dtypes.is_int(dtype):
+            overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
+            casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
+          elif dtypes.is_float(dtype):
+            casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
+          ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
       elif uop is UOps.LOAD:
         if isinstance(dtp[0], ImageDType):
           assert dtype.count == 4
@@ -136,9 +132,9 @@ class PythonProgram:
       elif uop is UOps.WMMA:
         # here are the models for the WMMA instruction on the different hardware
         def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
-          assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
-          assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
-          assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
+          assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
+          assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread, it has {len(inp[1])}"
+          assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread, it has {len(inp[2])}"
           assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
           assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
           assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
@@ -152,13 +148,13 @@ class PythonProgram:
           return out
 
         # TODO: refactor these to a shared TensorCoreLayout in kernel.py
-        if arg[5] == "METAL":
+        if arg[4] == "METAL":
           # A (2 elements on 32 threads): row major
           def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
           # (i, j), C, D (2 elements on 32 threads): row major same as A/B
           def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
           ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
-        elif arg[5] == "AMD":
+        elif arg[4] == "AMD":
           # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
           def a_elem(x, i, j, goff):
             assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
@@ -167,7 +163,7 @@ class PythonProgram:
           def b_elem(x, i, j, goff): return a_elem(x, j, i, goff) # pylint: disable=arguments-out-of-order
           def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
           ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-        elif arg[5] == "CUDA":
+        elif arg[4] == "CUDA":
           # A (8 elements on 32 threads)
           def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
           # B (4 elements on 32 threads)
@@ -191,8 +187,8 @@ class PythonRenderer(Renderer):
     if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
     if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
 
-  def render(self, name:str, uops:UOpGraph) -> str:
-    lops = [(u.op, u.dtype, [uops.uops.index(v) for v in u.src], u.arg) for u in uops]
+  def render(self, name:str, uops:List[UOp]) -> str:
+    lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
     return base64.b64encode(pickle.dumps(lops)).decode()
 
 class PythonCompiler(Compiler):
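
Note (not part of the diff): UOps.SPECIAL now carries the thread-index name in arg[0] and decodes the axis from the name itself, rather than reading a separate axis field. A minimal sketch of the new decoding, with made-up names:

    # axis decoding as in the 0.9.2 emulator above; the names here are illustrative
    for name in ["gidx0", "gidx2", "lidx1"]:
      kind, axis = name[0], int(name[-1])  # 'g' = global grid index, 'l' = local/warp index
      print(kind, 2 - axis)                # the emulator then looks up idxs[2-axis] / x[2-axis]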
tinygrad/runtime/support/compiler_cuda.py ADDED
@@ -0,0 +1,78 @@
+import subprocess, hashlib, tempfile, ctypes, ctypes.util, re, pathlib
+from typing import Callable
+from tinygrad.helpers import to_char_p_p, colored, init_c_var, getenv
+import tinygrad.runtime.autogen.nvrtc as nvrtc
+from tinygrad.device import Compiler, CompileError
+
+PTX = getenv("PTX") # this shouldn't be here, in fact, it shouldn't exist
+
+def _get_bytes(arg, get_str, get_sz, check) -> bytes:
+  sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))
+  return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
+
+def nvrtc_check(status, ctx=None):
+  if status != 0:
+    err_log = _get_bytes(ctx, nvrtc.nvrtcGetProgramLog, nvrtc.nvrtcGetProgramLogSize, lambda _: None).decode() if ctx else ""
+    raise CompileError(f"Nvrtc Error {status}, {ctypes.string_at(nvrtc.nvrtcGetErrorString(status)).decode()}\n{err_log}")
+
+def jitlink_check(status, ctx=None):
+  if status != 0:
+    err_log = _get_bytes(ctx, nvrtc.nvJitLinkGetErrorLog, nvrtc.nvJitLinkGetErrorLogSize, lambda _: None).decode() if ctx else ""
+    raise CompileError(f"NvJitLink Error {status}, {nvrtc.nvJitLinkResult__enumvalues.get(status, 'Unknown')}\n{err_log}")
+
+def pretty_ptx(s):
+  # all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
+  s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
+  s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
+  s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
+  s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
+  s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
+  s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
+  return s
+
+def cuda_disassemble(lib, arch):
+  try:
+    fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
+    with open(fn + ".ptx", "wb") as f: f.write(lib)
+    subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
+    print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
+  except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains ptxas/nvdisasm binary of compatible version.")
+
+def nv_disassemble(lib):
+  try:
+    fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
+    with open(fn + ".cubin", "wb") as f: f.write(lib)
+    print(subprocess.check_output(["nvdisasm", fn+".cubin"]).decode('utf-8'))
+  except Exception as e: print("Failed to disasm cubin:", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")
+
+class CUDACompiler(Compiler):
+  def __init__(self, arch:str, cache_key:str="cuda"):
+    self.arch, self.compile_options = arch, [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
+    nvrtc_check(nvrtc.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
+    if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
+    super().__init__(f"compile_{cache_key}_{self.arch}")
+  def _compile_program(self, src:str, nvrtc_get_content:Callable, nvrtc_get_size:Callable) -> bytes:
+    nvrtc_check(nvrtc.nvrtcCreateProgram(ctypes.byref(prog := nvrtc.nvrtcProgram()), src.encode(), "<null>".encode(), 0, None, None))
+    nvrtc_check(nvrtc.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options])), prog)
+    data = _get_bytes(prog, nvrtc_get_content, nvrtc_get_size, nvrtc_check)
+    nvrtc_check(nvrtc.nvrtcDestroyProgram(ctypes.byref(prog)))
+    return data
+  def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize)
+
+class NVCompiler(CUDACompiler):
+  def __init__(self, arch:str): super().__init__(arch, cache_key="nv")
+  def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetCUBIN, nvrtc.nvrtcGetCUBINSize)
+
+class PTXCompiler(CUDACompiler):
+  def __init__(self, arch:str, cache_key="ptx"): super().__init__(arch, cache_key=cache_key)
+  def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5").encode()
+
+class NVPTXCompiler(PTXCompiler):
+  def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx")
+  def compile(self, src:str) -> bytes:
+    jitlink_check(nvrtc.nvJitLinkCreate(handle := nvrtc.nvJitLinkHandle(), 1, to_char_p_p([f'-arch={self.arch}'.encode()])), handle)
+    jitlink_check(nvrtc.nvJitLinkAddData(handle, nvrtc.NVJITLINK_INPUT_PTX, ptxsrc:=super().compile(src), len(ptxsrc), "<null>".encode()), handle)
+    jitlink_check(nvrtc.nvJitLinkComplete(handle), handle)
+    data = _get_bytes(handle, nvrtc.nvJitLinkGetLinkedCubin, nvrtc.nvJitLinkGetLinkedCubinSize, jitlink_check)
+    jitlink_check(nvrtc.nvJitLinkDestroy(handle))
+    return data
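
Note: a hedged usage sketch for the new compiler classes (assumes a working local NVRTC install; the arch string and kernel source are invented for illustration):

    from tinygrad.runtime.support.compiler_cuda import CUDACompiler, NVCompiler
    src = "__global__ void noop() {}"          # trivial kernel, illustration only
    ptx = CUDACompiler("sm_86").compile(src)   # PTX bytes via nvrtcGetPTX
    cubin = NVCompiler("sm_86").compile(src)   # cubin bytes via nvrtcGetCUBIN
    print(ptx.decode().splitlines()[0])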
tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} RENAMED
@@ -1,5 +1,6 @@
-import ctypes
+import ctypes, subprocess
 import tinygrad.runtime.autogen.comgr as comgr
+from tinygrad.device import Compiler, CompileError
 
 def check(status):
   if status != 0:
@@ -54,3 +55,16 @@ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
   for x in [data_set_src, data_set_bc, data_set_reloc, data_set_exec]: check(comgr.amd_comgr_destroy_data_set(x))
   check(comgr.amd_comgr_destroy_action_info(action_info))
   return ret
+
+# this should probably be a method on the Compiler
+def disasm(lib):
+  asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
+  return '\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x])
+
+class AMDCompiler(Compiler):
+  def __init__(self, arch:str):
+    self.arch = arch
+    super().__init__(f"compile_hip_{self.arch}")
+  def compile(self, src:str) -> bytes:
+    try: return compile_hip(src, self.arch)
+    except RuntimeError as e: raise CompileError(e) from e
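
Note: a usage sketch for the new AMDCompiler (assumes a ROCm install with comgr and llvm-objdump; arch and kernel source are illustrative):

    from tinygrad.runtime.support.compiler_hip import AMDCompiler, disasm
    lib = AMDCompiler("gfx1100").compile('extern "C" __global__ void noop() {}')
    print(disasm(lib))  # shells out to /opt/rocm/llvm/bin/llvm-objdump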
tinygrad/runtime/support/elf.py ADDED
@@ -0,0 +1,38 @@
+from __future__ import annotations
+from typing import Tuple, List, Any
+from dataclasses import dataclass
+import tinygrad.runtime.autogen.libc as libc
+
+@dataclass(frozen=True)
+class ElfSection: name:str; header:libc.Elf64_Shdr; content:bytes # noqa: E702
+
+def elf_loader(blob:bytes, force_section_align:int=1) -> Tuple[memoryview, List[ElfSection], Any]:
+  def _strtab(blob: bytes, idx: int) -> str: return blob[idx:blob.find(b'\x00', idx)].decode('utf-8')
+
+  header = libc.Elf64_Ehdr.from_buffer_copy(blob)
+  section_headers = (libc.Elf64_Shdr * header.e_shnum).from_buffer_copy(blob[header.e_shoff:])
+  sh_strtab = blob[(shstrst:=section_headers[header.e_shstrndx].sh_offset):shstrst+section_headers[header.e_shstrndx].sh_size]
+  sections = [ElfSection(_strtab(sh_strtab, sh.sh_name), sh, blob[sh.sh_offset:sh.sh_offset+sh.sh_size]) for sh in section_headers]
+
+  def _to_carray(sh, ctype): return (ctype * (sh.header.sh_size // sh.header.sh_entsize)).from_buffer_copy(sh.content)
+  rel = [(sh, sh.name[4:], _to_carray(sh, libc.Elf64_Rel)) for sh in sections if sh.header.sh_type == libc.SHT_REL]
+  rela = [(sh, sh.name[5:], _to_carray(sh, libc.Elf64_Rela)) for sh in sections if sh.header.sh_type == libc.SHT_RELA]
+  symtab = [_to_carray(sh, libc.Elf64_Sym) for sh in sections if sh.header.sh_type == libc.SHT_SYMTAB][0]
+  progbits = [sh for sh in sections if sh.header.sh_type == libc.SHT_PROGBITS]
+
+  # Prealloc image for all fixed addresses.
+  image = bytearray(max([sh.header.sh_addr + sh.header.sh_size for sh in progbits if sh.header.sh_addr != 0] + [0]))
+  for sh in progbits:
+    if sh.header.sh_addr != 0: image[sh.header.sh_addr:sh.header.sh_addr+sh.header.sh_size] = sh.content
+    else:
+      image += b'\0' * (((align:=max(sh.header.sh_addralign, force_section_align)) - len(image) % align) % align) + sh.content
+      sh.header.sh_addr = len(image) - len(sh.content)
+
+  # Relocations
+  relocs = []
+  for sh, trgt_sh_name, c_rels in rel + rela:
+    target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
+    rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
+    relocs += [(target_image_off + roff, sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]
+
+  return memoryview(image), sections, relocs
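
Note: elf_loader flattens the PROGBITS sections of a relocatable ELF64 object into a single image and returns the relocation sites the caller must patch. A sketch of driving it (the .o path is hypothetical):

    from tinygrad.runtime.support.elf import elf_loader
    with open("kernel.o", "rb") as f:  # any relocatable ELF64 object
      image, sections, relocs = elf_loader(f.read(), force_section_align=0x10)
    for offset, target, r_type, r_addend in relocs:  # fixups the runtime applies to image
      print(f"patch image[{offset:#x}] -> {target:#x} (type={r_type}, addend={r_addend})")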
tinygrad/shape/shapetracker.py CHANGED
@@ -3,18 +3,9 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
 from tinygrad.helpers import merge_dicts, getenv
-from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
+from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
 from tinygrad.shape.view import View, strides_for_shape
 
-def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
-  assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
-  iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
-  vexpr: List[Node] = [valid] if valid is not None else []
-  for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
-    if sh != 1 and st != 0: iexpr.append(idx*st)
-    if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
-  return Node.sum(iexpr), Node.ands(vexpr)
-
 @dataclass(frozen=True)
 class ShapeTracker:
   views: Tuple[View, ...]
@@ -32,7 +23,7 @@ class ShapeTracker:
     return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
 
   @staticmethod
-  def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
+  def from_shape(shape:Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
 
   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
@@ -56,7 +47,7 @@ class ShapeTracker:
     assert isinstance(ret, int), f"ret must be integer, {ret=} isn't"
     return ret+1
 
-  def vars(self) -> Set[Variable]: return set.union(*[v.vars() for v in self.views], set())
+  def vars(self) -> Set[Variable]: return set().union(*[v.vars() for v in self.views])
 
   @property
   def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
@@ -86,7 +77,7 @@ class ShapeTracker:
 
   def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
     idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
-    idx, valid = _expr_view(self.views[-1], idxs)
+    idx, valid = self.views[-1].expr(idxs)
     for view in reversed(self.views[0:-1]):
       if valid.max == 0: return NumNode(-1), valid
       view = view.minify()
@@ -94,7 +85,7 @@ class ShapeTracker:
       for d in reversed(view.shape):
         idxs.append((idx//acc)%d)
         acc *= d
-      idx, valid = _expr_view(view, idxs[::-1], valid)
+      idx, valid = view.expr(idxs[::-1], valid)
     assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
     assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
     return idx, valid
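
Note: the free function _expr_view has moved onto View as View.expr (see the view.py diff below); expr_idxs behaves as before. A quick sketch of the unchanged observable behavior:

    from tinygrad.shape.shapetracker import ShapeTracker
    st = ShapeTracker.from_shape((3, 4))  # contiguous, strides (4, 1)
    idx, valid = st.expr_idxs()           # now routed through View.expr internally
    print(idx.render(), valid.render())   # roughly ((idx0*4)+idx1) and 1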
tinygrad/shape/symbolic.py CHANGED
@@ -43,6 +43,7 @@ class Node:
     if b == 1: return self
     return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
   def __rmul__(self, b:int): return self*b
+  def __lshift__(self, b:int): return self*2**b
 
   # *** complex ops ***
 
@@ -74,7 +75,6 @@ class Node:
     assert b > 0
     if b == 1: return NumNode(0)
     if isinstance(self.max, int) and isinstance(self.min, int):
-      if self.min >= 0 and self.max < b: return self
       if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
       if self.min < 0: return (self - ((self.min//b)*b)) % b
     return create_node(ModNode(self, b))
@@ -231,7 +231,7 @@ class RedNode(Node):
   def __init__(self, nodes:List[Node]):
     self.nodes = nodes
     self.min, self.max = self.get_bounds()
-  def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set())
+  def vars(self) -> Set[Variable]: return set().union(*[x.vars() for x in self.nodes])
   def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
 
 class SumNode(RedNode):
@@ -291,11 +291,7 @@ class SumNode(RedNode):
 class AndNode(RedNode):
   def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
   def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
-    subed = []
-    for node in self.nodes:
-      if not (sub:=node.substitute(var_vals)): return NumNode(0)
-      subed.append(sub)
-    return Node.ands(subed)
+    return Node.ands([node.substitute(var_vals) for node in self.nodes])
 
 def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
 def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
@@ -324,4 +320,4 @@ render_python: Dict[Type, Callable[..., str]] = {
   LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
   SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
   AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
-}
+}
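
Note: the new Node.__lshift__ defines shift-left on symbolic nodes as multiplication by a power of two. A tiny sketch (the Variable bounds are arbitrary):

    from tinygrad.shape.symbolic import Variable
    x = Variable("x", 0, 7)
    print(x << 2)            # renders as x*4, since x << b is defined as x * 2**b
    print((x << 2) == x*4)   # True: both reduce to the same MulNode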
tinygrad/shape/view.py CHANGED
@@ -3,7 +3,7 @@ import functools, operator, itertools, math
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, cast
 from tinygrad.helpers import prod, all_int, argsort
-from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
+from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer, create_lt_node, create_ge_node
 
 @functools.lru_cache(maxsize=None)
 def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
@@ -35,14 +35,17 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu
   return tuple(ret)
 
 @functools.lru_cache(maxsize=None)
-def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tuple[Tuple[sint, sint], ...]], bool]:
-  if view.mask is None: return view.mask, False
-  if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in view.mask): return view.mask, True
-  new_mask: List[Tuple[int, int]] = []
+def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) \
+    -> Optional[Tuple[Tuple[sint, sint], ...]]:
+  """Returns the new mask if reshape is possible, and None if not possible."""
+  if _mask is None: return tuple((0, s) for s in new_shape)
+  if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in _mask): return None
+  if any(m[1] - m[0] < 1 for m in _mask): return ((0, 0),) * len(new_shape) # zero mask
 
-  r_masks, r_shape, r_new_shape = reversed(view.mask), reversed(view.shape), reversed(new_shape)
+  new_mask: List[Tuple[int, int]] = []
+  # _mask is all int here
+  r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
   curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
-  if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
 
   while len(new_mask) < len(new_shape):
     (l, r), next_stride = mask, new_dim * curr_stride
@@ -51,24 +54,23 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
       if old_dim == next_stride: # simply copy the mask and get next batch for merging
         new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
         curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
-        if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
 
       else: # mask can only be splitted if reshape doesn't cut across the mask.
        if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
-            or old_dim % next_stride != 0): return view.mask, True
+            or old_dim % next_stride != 0): return None
        new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
        curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
 
     else:
       next_mask = next(r_masks, (0, 1))
       # combine if the mask can unfold continuously
-      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, True
+      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return None
      mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
 
   for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
-    if mask != (0, 1): return ((0, 0),) * len(new_shape), False # invalid mask
+    if mask != (0, 1): return ((0, 0),) * len(new_shape) # invalid mask
 
-  return tuple(reversed(new_mask)), False
+  return tuple(reversed(new_mask))
 
 def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
   strides = strides_for_shape(shape)
@@ -97,6 +99,7 @@ class View:
   @staticmethod
   @functools.lru_cache(maxsize=None)
   def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
+    if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
     strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
     # canonicalize 0 in shape
     if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
@@ -120,13 +123,13 @@ class View:
 
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
   def unbind(self) -> Tuple[View, Dict[Variable, int]]:
-    var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.val is not None]
+    var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
     unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
-    new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
-    new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
-    new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
-    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
-                      b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
+    def substitute(x): return x if isinstance(x, int) else x.substitute(unbound_vars)
+    new_shape = tuple(map(substitute, self.shape))
+    new_strides = tuple(map(substitute, self.strides))
+    new_offset = substitute(self.offset)
+    new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
     return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
@@ -301,11 +304,20 @@ class View:
         if acc != merged_dim: break
       else:
        strides += [0,] * (len(new_shape) - len(strides))
-        new_mask, extra = _reshape_mask(self, new_shape)
-        if not extra:
-          new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides)))
+        new_mask = _reshape_mask(self.mask, self.shape, new_shape)
+        if new_mask is not None:
+          new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
          extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
-                         (sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0)
+                         (sum(m[0] * s for m,s in zip(new_mask, new_strides)))
          return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
 
     return None
+
+  def expr(self, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
+    assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
+    iexpr: List[Node] = [NumNode(self.offset) if isinstance(self.offset, int) else self.offset]
+    vexpr: List[Node] = [valid] if valid is not None else []
+    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
+      if sh != 1 and st != 0: iexpr.append(idx*st)
+      if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
+    return Node.sum(iexpr), Node.ands(vexpr)
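
Note: _reshape_mask's new contract returns the reshaped mask, or None when the mask cannot be carried through the reshape (it previously returned a (mask, extra) pair). A sketch with toy shapes:

    from tinygrad.shape.view import _reshape_mask
    # rows 0-1 of a (4, 3) view stay masked after flattening to (12,)
    print(_reshape_mask(((0, 2), (0, 3)), (4, 3), (12,)))   # ((0, 6),)
    # a 3-wide dim can't be split by a new inner dim of 2, so the mask can't be carried over
    print(_reshape_mask(((0, 2), (0, 3)), (4, 3), (6, 2)))  # None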