tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,138 @@
|
|
1
|
+
import functools, importlib, re, urllib
import operator
import urllib.error  # `import urllib` alone does not load the urllib.error submodule referenced in import_asic_regs
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.helpers import getbits, round_up, fetch
from tinygrad.runtime.autogen import pci
from tinygrad.runtime.support.usb import ASM24Controller
|
7
|
+
|
8
|
+
@dataclass
class AMDReg:
  """A single AMD ASIC register.

  fields: lowercase field name -> (start bit, end bit), both inclusive.
  bases: IP instance -> tuple of segment base addresses; `segment` indexes into that tuple.
  After construction, `addr` maps each IP instance to the register's absolute address.
  """
  name:str; offset:int; segment:int; fields:dict[str, tuple[int, int]]; bases:dict[int, tuple[int, ...]] # noqa: E702

  def __post_init__(self):
    # absolute address per IP instance: that instance's base for our segment plus the register offset
    self.addr:dict[int, int] = { inst: bases[self.segment] + self.offset for inst, bases in self.bases.items() }

  def encode(self, **kwargs) -> int:
    """Build a register value by shifting each named field value to its start bit and OR-ing them together.

    Values are not masked to the field width. operator.index accepts any integer-like value (e.g. numpy integers)
    and always yields a Python int; int.__or__ would return NotImplemented / raise TypeError for those.
    """
    return functools.reduce(operator.or_, (operator.index(value) << self.fields[name][0] for name,value in kwargs.items()), 0)

  def decode(self, val: int) -> dict:
    """Split a raw register value into {field name: field value}."""
    return {name:getbits(val, start, end) for name,(start,end) in self.fields.items()}

  def fields_mask(self, *names) -> int:
    """Return the bit mask covering all of the named fields."""
    return functools.reduce(int.__or__, ((((1 << (self.fields[nm][1]-self.fields[nm][0]+1)) - 1) << self.fields[nm][0]) for nm in names), 0)
|
18
|
+
|
19
|
+
@dataclass
class AMDIP:
  """An AMD IP block identified by name and version whose registers are exposed as attributes (e.g. ip.regFOO).

  `version` is normalized in __post_init__ to the first candidate from fixup_ip_version. `bases` (IP instance ->
  tuple of segment base addresses) is passed to every register object built by import_asic_regs.
  """
  name:str; version:tuple[int, ...]; bases:dict[int, tuple[int, ...]] # noqa: E702
  def __post_init__(self): self.version = fixup_ip_version(self.name, self.version)[0]

  @functools.cached_property
  def regs(self): return import_asic_regs(self.name, self.version, cls=functools.partial(AMDReg, bases=self.bases))

  def __getattr__(self, name:str):
    # Only reached when normal lookup fails. Private/dunder probes (copy, pickle, hasattr) and `regs` itself must not
    # trigger the register-header download or recurse, and must raise AttributeError so getattr defaults work.
    if name.startswith('_') or name == 'regs': raise AttributeError(name)
    regs = self.regs
    if name in regs: return regs[name]

    # NOTE: gfx10 gc registers always start with mm, no reg prefix
    if name.startswith('reg') and (mm_name:=f"mm{name[len('reg'):]}") in regs: return regs[mm_name]
    raise AttributeError(f"{type(self).__name__} {self.name!r} has no register {name!r}")
|
32
|
+
|
33
|
+
def fixup_ip_version(ip:str, version:tuple[int, ...]) -> list[tuple[int, ...]]:
  """Return the versions to try for IP block `ip`, most specific first.

  A few (ip, version-prefix) pairs are first remapped to another version via the override table below. The candidates
  are then: the full version, major.minor, major.minor.0 and major.0.0. Duplicates are dropped while keeping order, so
  callers iterating the list don't retry (and re-download) the same version, e.g. (11,0,0) would otherwise appear 3 times.
  """
  # override versions: ip -> {version prefix: replacement version}; the first matching prefix wins
  nbio_ovrd, mp_ovrd = {(3,3): (2,3,0)}, {(14,0,3): (14,0,2)}
  overrides: dict[str, dict[tuple[int, ...], tuple[int, ...]]] = {
    'nbio': nbio_ovrd, 'nbif': nbio_ovrd, 'mp': mp_ovrd, 'smu': mp_ovrd, 'gc': {(9,5,0): (9,4,3)}}
  for ver, ovrd_ver in overrides.get(ip, {}).items():
    if version[:len(ver)] == ver:
      version = ovrd_ver
      break

  return list(dict.fromkeys([version, version[:2], version[:2]+(0,), version[:1]+(0, 0)]))
|
45
|
+
|
46
|
+
def header_download(file, name=None, subdir="defines") -> str:
  """Fetch `file` (a path relative to drivers/gpu/drm/amd in a pinned linux-next commit) and return its text.

  `name` and `subdir` are forwarded to tinygrad's `fetch`.
  """
  amd_drm_root = "https://gitlab.com/linux-kernel/linux-next/-/raw/cf6d949a409e09539477d32dbe7c954e4852e744/drivers/gpu/drm/amd"
  downloaded = fetch(f"{amd_drm_root}/{file}", name=name, subdir=subdir)
  return downloaded.read_text()
|
49
|
+
|
50
|
+
def import_header(path:str):
  """Download the header at `path` and return {identifier: int} for every `NAME = <hex or decimal>` assignment in it.

  Comments are removed first, so commented-out assignments are ignored.
  """
  # `//` comments stop at the end of the line. Under re.S a bare `//.*` would also match newlines and
  # delete everything after the first `//` in the file.
  t = re.sub(r'//[^\n]*|/\*.*?\*/','', header_download(path, subdir="defines"), flags=re.S)
  return {k:int(v,0) for k,v in re.findall(r'\b([A-Za-z_]\w*)\s*=\s*(0x[0-9A-Fa-f]+|\d+)', t)}
|
53
|
+
|
54
|
+
def import_module(name:str, version:tuple[int, ...], version_prefix:str=""):
  """Import tinygrad.runtime.autogen.am.<name>_<version_prefix><version>, trying each candidate from fixup_ip_version.

  Raises ImportError naming the IP and requested version when none of the candidates can be imported.
  """
  for candidate in fixup_ip_version(name, version):
    module_path = f"tinygrad.runtime.autogen.am.{name}_{version_prefix}{'_'.join(map(str, candidate))}"
    try:
      return importlib.import_module(module_path)
    except ImportError:
      continue
  raise ImportError(f"Failed to load autogen module for {name.upper()} {'.'.join(map(str, version))}")
|
59
|
+
|
60
|
+
def import_soc(ip):
  """Return a class named "SOC" whose attributes are the enum values of the SoC header picked by ip[0] (the major version)."""
  soc_header = {9: 'vega10', 10: 'navi10', 11: 'soc21', 12: 'soc24'}[ip[0]]
  return type("SOC", (object,), import_header(f"include/{soc_header}_enum.h"))
|
61
|
+
|
62
|
+
def import_asic_regs(prefix:str, version:tuple[int, ...], cls=AMDReg) -> dict[str, AMDReg]:
  """Download the asic_reg offset and sh_mask headers for IP `prefix` and build one `cls` object per register.

  Candidate versions come from fixup_ip_version; a candidate whose headers return HTTP 404 is skipped. Other HTTP
  errors propagate. Raises ImportError when no candidate is available.
  """
  def _split_name(name):
    # split at the first uppercase char: 'regGCVM_X' -> ('reg', 'GCVM_X'); no uppercase -> (name, '')
    for idx, ch in enumerate(name):
      if ch.isupper(): return name[:idx], name[idx:]
    return name, ''

  def _extract_regs(txt):
    # every `#define NAME <hex or decimal>` line -> {NAME: value}
    found = {}
    for line in txt.splitlines():
      m = re.match(r'#define\s+(\S+)\s+(0x[\da-fA-F]+|\d+)', line)
      if m: found[m.group(1)] = int(m.group(2), 0)
    return found

  dir_prefix = {"osssys": "oss"}.get(prefix, prefix)
  def _download_file(ver, suff) -> str:
    # fetched under the candidate version, but cached under the name of the requested version
    fetch_name = f"{prefix}_{'_'.join(map(str, ver))}_{suff}.h"
    file_name = f"{prefix}_{'_'.join(map(str, version))}_{suff}.h"
    return header_download(f"include/asic_reg/{dir_prefix}/{fetch_name}", name=file_name, subdir="asic_regs")

  for ver in fixup_ip_version(prefix, version):
    try:
      offs = _extract_regs(_download_file(ver, "offset"))
      sh_masks = _extract_regs(_download_file(ver, "sh_mask"))
    except urllib.error.HTTPError as e:
      if e.code != 404: raise
      continue

    # register offsets and their *_BASE_IDX segment indices, only for 'reg'/'mm'-prefixed names
    offsets, bases = {}, {}
    for key, val in offs.items():
      if _split_name(key)[0] not in {'reg', 'mm'}: continue
      if key.endswith('_BASE_IDX'): bases[key[:-len('_BASE_IDX')]] = val
      else: offsets[key] = val

    # REG__FIELD_MASK -> fields[REG][field] = (lowest set bit, highest set bit)
    fields: defaultdict[str, dict[str, tuple[int, int]]] = defaultdict(dict)
    for mask_name, mask in sh_masks.items():
      if '__' not in mask_name or not mask_name.endswith('_MASK'): continue
      reg_name, reg_field_name = mask_name[:-len('_MASK')].split('__')
      low_bit, high_bit = (mask & -mask).bit_length() - 1, mask.bit_length() - 1
      fields[reg_name][reg_field_name.lower()] = (low_bit, high_bit)

    # NOTE: Some registers like regGFX_IMU_FUSESTRAP in gc_11_0_0 are missing base idx, just skip them
    result = {}
    for reg, off in offsets.items():
      if reg not in bases: continue
      result[reg] = cls(name=reg, offset=off, segment=bases[reg], fields=fields[_split_name(reg)[1]])
    return result
  raise ImportError(f"Failed to load ASIC registers for {prefix.upper()} {'.'.join(map(str, version))}")
|
89
|
+
|
90
|
+
def setup_pci_bars(usb:ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, tuple[int, int]]:
  """Program the PCIe path to the GPU behind `usb` and assign addresses to the GPU's memory BARs.

  For each bus in range(gpu_bus) (dev 0, fn 0): write the primary/secondary/subordinate bus numbers, open the memory
  and prefetchable windows starting at mem_base / pref_mem_base, and enable IO/memory/bus-master.
  On the GPU: walk the extended capabilities and, for the Resizable BAR capability, write the index of the highest bit
  of (cap >> 4) into bits 8..12 of the control register (presumably selecting the largest advertised size). Then size
  each memory BAR among the six BAR registers and place it at the current address of the non-prefetchable
  (mem_base) or prefetchable (pref_mem_base) window, advancing that window by the size rounded up to 2 MiB.

  Returns {bar_index: (assigned_address, size_in_bytes)} for every memory BAR.
  """
  cmd = pci.PCI_COMMAND_IO | pci.PCI_COMMAND_MEMORY | pci.PCI_COMMAND_MASTER
  for bus in range(gpu_bus):
    # All 3 values must be written at the same time.
    buses = (0 << 0) | ((bus+1) << 8) | ((gpu_bus) << 16)
    usb.pcie_cfg_req(pci.PCI_PRIMARY_BUS, bus=bus, dev=0, fn=0, value=buses, size=4)

    usb.pcie_cfg_req(pci.PCI_MEMORY_BASE, bus=bus, dev=0, fn=0, value=(mem_base>>16) & 0xffff, size=2)
    usb.pcie_cfg_req(pci.PCI_MEMORY_LIMIT, bus=bus, dev=0, fn=0, value=0xffff, size=2)
    usb.pcie_cfg_req(pci.PCI_PREF_MEMORY_BASE, bus=bus, dev=0, fn=0, value=(pref_mem_base>>16) & 0xffff, size=2)
    usb.pcie_cfg_req(pci.PCI_PREF_MEMORY_LIMIT, bus=bus, dev=0, fn=0, value=0xffff, size=2)
    usb.pcie_cfg_req(pci.PCI_PREF_BASE_UPPER32, bus=bus, dev=0, fn=0, value=pref_mem_base >> 32, size=4)
    usb.pcie_cfg_req(pci.PCI_PREF_LIMIT_UPPER32, bus=bus, dev=0, fn=0, value=0xffffffff, size=4)

    usb.pcie_cfg_req(pci.PCI_COMMAND, bus=bus, dev=0, fn=0, value=cmd, size=1)

  # config-space access to the GPU itself (bus gpu_bus, dev 0, fn 0)
  def gpu_cfg(off:int, **kwargs): return usb.pcie_cfg_req(off, bus=gpu_bus, dev=0, fn=0, **kwargs)

  # resize bar 0
  cap_ptr = 0x100
  while cap_ptr:
    hdr = gpu_cfg(cap_ptr, size=4)
    # an all-ones read means the access failed; its "next" pointer is non-zero, so stop instead of looping forever
    if hdr == 0xffffffff: break
    if pci.PCI_EXT_CAP_ID(hdr) == pci.PCI_EXT_CAP_ID_REBAR:
      cap = gpu_cfg(cap_ptr + 0x04, size=4)
      new_ctrl = (gpu_cfg(cap_ptr + 0x08, size=4) & ~0x1F00) | ((int(cap >> 4).bit_length() - 1) << 8)
      gpu_cfg(cap_ptr + 0x08, value=new_ctrl, size=4)

    cap_ptr = pci.PCI_EXT_CAP_NEXT(hdr)

  mem_space_addr, bar_off, bars = [mem_base, pref_mem_base], 0, {}
  while bar_off < 24:
    bar_reg = pci.PCI_BASE_ADDRESS_0 + bar_off
    cfg = gpu_cfg(bar_reg, size=4)
    is_mem = (cfg & pci.PCI_BASE_ADDRESS_SPACE) == pci.PCI_BASE_ADDRESS_SPACE_MEMORY
    # The prefetch and 64-bit type bits only exist in memory BARs; in an I/O BAR they are address bits and must not
    # make us skip the next BAR register. bar_mem indexes mem_space_addr (0 = non-prefetchable, 1 = prefetchable).
    bar_mem, bar_64 = bool(cfg & pci.PCI_BASE_ADDRESS_MEM_PREFETCH), is_mem and bool(cfg & pci.PCI_BASE_ADDRESS_MEM_TYPE_64)

    if is_mem:
      # size the BAR: write all ones, read back, size = two's complement of the address bits
      gpu_cfg(bar_reg, value=0xffffffff, size=4)
      lo = (gpu_cfg(bar_reg, size=4) & 0xfffffff0)

      if bar_64: gpu_cfg(bar_reg + 4, value=0xffffffff, size=4)
      hi = (gpu_cfg(bar_reg + 4, size=4) if bar_64 else 0)

      bar_size = ((~(((hi << 32) | lo) & ~0xf)) + 1) & (0xffffffffffffffff if bar_64 else 0xffffffff)

      gpu_cfg(bar_reg, value=mem_space_addr[bar_mem] & 0xffffffff, size=4)
      if bar_64: gpu_cfg(bar_reg + 4, value=mem_space_addr[bar_mem] >> 32, size=4)

      bars[bar_off // 4] = (mem_space_addr[bar_mem], bar_size)
      mem_space_addr[bar_mem] += round_up(bar_size, 2 << 20)

    bar_off += 8 if bar_64 else 4

  gpu_cfg(pci.PCI_COMMAND, value=cmd, size=1)
  return bars
|
@@ -1,6 +1,20 @@
|
|
1
1
|
import ctypes, subprocess
|
2
2
|
import tinygrad.runtime.autogen.comgr as comgr
|
3
|
+
# Sanity check of the default autogen bindings: they encode AMD_COMGR_LANGUAGE_HIP as 4.
assert comgr.AMD_COMGR_LANGUAGE_HIP == 4
try:
  # ask the loaded comgr library for its version; major/minor are filled in through the pointers
  comgr.amd_comgr_get_version(ctypes.byref(major:=ctypes.c_uint64()), ctypes.byref(minor:=ctypes.c_uint64()))
  if major.value >= 3:
    # in comgr 3 the values of enums in headers were changed: https://github.com/ROCm/llvm-project/issues/272
    # so rebind the module-level `comgr` name to the bindings generated for comgr 3
    import tinygrad.runtime.autogen.comgr_3 as comgr # type: ignore[no-redef]
    assert comgr.AMD_COMGR_LANGUAGE_HIP == 3
except AttributeError: pass # ignore if ROCm isn't installed
|
3
11
|
from tinygrad.device import Compiler, CompileError
|
12
|
+
from tinygrad.runtime.ops_llvm import LLVMCompiler
|
13
|
+
from tinygrad.helpers import OSX, to_char_p_p
|
14
|
+
|
15
|
+
def amdgpu_disassemble(lib:bytes):
  """Disassemble an AMDGPU binary with llvm-objdump (ROCm's copy unless on macOS) and print it, omitting s_code_end lines."""
  objdump = "llvm-objdump" if OSX else "/opt/rocm/llvm/bin/llvm-objdump"
  listing = subprocess.check_output([objdump, '-d', '-'], input=lib).decode('utf-8')
  kept = [line for line in listing.split("\n") if 's_code_end' not in line]
  print('\n'.join(kept))
|
4
18
|
|
5
19
|
def check(status):
|
6
20
|
if status != 0:
|
@@ -14,6 +28,12 @@ def _get_comgr_data(data_set, data_type):
|
|
14
28
|
check(comgr.amd_comgr_release_data(data_exec))
|
15
29
|
return bytes(dat)
|
16
30
|
|
31
|
+
# amd_comgr_action_info_set_options was deprecated
|
32
|
+
def set_options(action_info, options:bytes):
  """Set the option list of a comgr action info to the whitespace-separated words of `options`.

  Uses amd_comgr_action_info_set_option_list (amd_comgr_action_info_set_options was deprecated) and returns its
  status, which callers pass to check().
  """
  # TODO: this type should be correct in the autogen stub
  comgr.amd_comgr_action_info_set_option_list.argtypes = [comgr.amd_comgr_action_info_t, ctypes.POINTER(ctypes.POINTER(ctypes.c_char)), comgr.size_t]
  # split() without a separator drops empty words: b"" clears the list instead of passing a single empty-string
  # option, and repeated spaces no longer produce empty options
  options_list = options.split()
  return comgr.amd_comgr_action_info_set_option_list(action_info, to_char_p_p(options_list), len(options_list))
|
36
|
+
|
17
37
|
# AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
|
18
38
|
def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
|
19
39
|
check(comgr.amd_comgr_create_action_info(ctypes.byref(action_info := comgr.amd_comgr_action_info_t())))
|
@@ -40,15 +60,15 @@ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
|
|
40
60
|
check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
|
41
61
|
check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
|
42
62
|
# -include hiprtc_runtime.h was removed
|
43
|
-
check(
|
63
|
+
check(set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes -Xclang -aux-triple -Xclang x86_64-unknown-linux-gnu".encode())) # noqa: E501
|
44
64
|
status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
|
45
65
|
if status != 0:
|
46
66
|
print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
|
47
67
|
raise RuntimeError("compile failed")
|
48
|
-
check(
|
68
|
+
check(set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
|
49
69
|
check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
|
50
70
|
|
51
|
-
check(
|
71
|
+
check(set_options(action_info, b""))
|
52
72
|
check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action_info, data_set_reloc, data_set_exec))
|
53
73
|
ret = _get_comgr_data(data_set_exec, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)
|
54
74
|
check(comgr.amd_comgr_release_data(data_src))
|
@@ -56,13 +76,25 @@ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
|
|
56
76
|
check(comgr.amd_comgr_destroy_action_info(action_info))
|
57
77
|
return ret
|
58
78
|
|
59
|
-
class
|
79
|
+
class HIPCompiler(Compiler):
  """Compiler backed by comgr (compile_hip) for a given AMD gfx arch; cached under compile_hip_<arch>."""
  def __init__(self, arch:str):
    self.arch = arch
    super().__init__(f"compile_hip_{self.arch}")

  def compile(self, src:str) -> bytes:
    # a source whose first line is exactly '.text' is passed to compile_hip with its `asm` flag set
    first_line = src.split('\n', 1)[0]
    is_asm = first_line.strip() == '.text'
    try:
      return compile_hip(src, self.arch, is_asm)
    except RuntimeError as e:
      raise CompileError(e) from e

  def disassemble(self, lib:bytes):
    amdgpu_disassemble(lib)
|
87
|
+
|
88
|
+
class AMDLLVMCompiler(LLVMCompiler):
  """LLVMCompiler targeting AMDGPU for the given gfx arch, created with the "+cumode" feature string."""
  jit = False
  target_arch = "AMDGPU"
  def __init__(self, arch: str):
    self.arch = arch
    super().__init__(self.arch, "+cumode")
  # pickle as a constructor call with only the arch, so an unpickled instance is rebuilt through __init__
  def __reduce__(self): return (AMDLLVMCompiler, (self.arch,))
  def compile(self, src:str) -> bytes:
    try: return super().compile(src)
    except RuntimeError as e:
      # an undefined '@llvm.amdgcn.*' value is reported with a hint about the required LLVM version;
      # the hint goes on its own line instead of being glued onto the end of the LLVM error text
      if "undefined value '@llvm.amdgcn." in str(e): raise CompileError(f"{e}\nAMD with LLVM backend requires LLVM >= 18") from e
      raise CompileError(e) from e
  def disassemble(self, lib:bytes): amdgpu_disassemble(lib)
|
@@ -4,7 +4,7 @@ from tinygrad.helpers import to_char_p_p, colored, init_c_var, getenv
|
|
4
4
|
import tinygrad.runtime.autogen.nvrtc as nvrtc
|
5
5
|
from tinygrad.device import Compiler, CompileError
|
6
6
|
|
7
|
-
PTX = getenv("PTX") #
|
7
|
+
# NOTE: PTX shouldn't be here, in fact, it shouldn't exist
PTX = getenv("PTX")
# CUDA install prefix; when empty, CUDACompiler falls back to a fixed list of include directories
CUDA_PATH = getenv("CUDA_PATH", "")
|
8
8
|
|
9
9
|
def _get_bytes(arg, get_str, get_sz, check) -> bytes:
|
10
10
|
sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))
|
@@ -30,17 +30,18 @@ def pretty_ptx(s):
|
|
30
30
|
s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
|
31
31
|
return s
|
32
32
|
|
33
|
-
def cuda_disassemble(lib:bytes, arch:str):
  """Print the nvdisasm output for `lib`; failures are printed, never raised.

  `lib` is first run through `ptxas -arch=<arch>` (optional ptx -> sass step for CUDA=1). If ptxas fails (presumably
  because `lib` is not PTX), nvdisasm is run on `lib` as-is.
  """
  try:
    # private, auto-removed directory instead of a predictable, never-deleted name in the shared temp dir
    with tempfile.TemporaryDirectory(prefix="tinycuda_") as tmp:
      src_fn, out_fn = (pathlib.Path(tmp) / "prog").as_posix(), (pathlib.Path(tmp) / "prog.cubin").as_posix()
      with open(src_fn, "wb") as f: f.write(lib)
      ptxas = subprocess.run(["ptxas", f"-arch={arch}", "-o", out_fn, src_fn], check=False, stderr=subprocess.DEVNULL) # optional ptx -> sass step for CUDA=1
      print(subprocess.check_output(['nvdisasm', out_fn if ptxas.returncode == 0 else src_fn]).decode('utf-8'))
  except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains ptxas/nvdisasm binary of compatible version.")
|
40
40
|
|
41
41
|
class CUDACompiler(Compiler):
|
42
42
|
def __init__(self, arch:str, cache_key:str="cuda"):
  """Build the nvrtc options for `arch`; compile results are cached under compile_<cache_key>_<arch>."""
  self.arch = arch
  self.compile_options = [f'--gpu-architecture={arch}']
  # header search path: $CUDA_PATH/include if CUDA_PATH is set, otherwise a fixed list of common locations
  if CUDA_PATH:
    self.compile_options += [f"-I{CUDA_PATH}/include"]
  else:
    self.compile_options += ["-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include"]
  nvrtcMajor, nvrtcMinor = ctypes.c_int(), ctypes.c_int()
  nvrtc_check(nvrtc.nvrtcVersion(nvrtcMajor, nvrtcMinor))
  # "--minimal" is only added for nvrtc 12.4 and newer
  if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4):
    self.compile_options.append("--minimal")
  super().__init__(f"compile_{cache_key}_{self.arch}")
|
@@ -51,12 +52,7 @@ class CUDACompiler(Compiler):
|
|
51
52
|
nvrtc_check(nvrtc.nvrtcDestroyProgram(ctypes.byref(prog)))
|
52
53
|
return data
|
53
54
|
def compile(self, src:str) -> bytes:
  # compile through nvrtc and return the generated PTX
  return self._compile_program(src, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize)
|
54
|
-
def disassemble(self, lib:bytes):
  # delegate to the module-level helper, using this compiler's arch for ptxas
  cuda_disassemble(lib, self.arch)
|
60
56
|
|
61
57
|
class NVCompiler(CUDACompiler):
|
62
58
|
def __init__(self, arch:str):
  # identical to CUDACompiler except for the "nv" compile-cache key
  super().__init__(arch, cache_key="nv")
|
@@ -67,6 +63,7 @@ class PTXCompiler(Compiler):
|
|
67
63
|
self.arch = arch
|
68
64
|
super().__init__(f"compile_{cache_key}_{self.arch}")
|
69
65
|
def compile(self, src:str) -> bytes:
  """Fill the TARGET and VERSION placeholders of a PTX template: VERSION is 7.8 for sm_89 and newer, else 7.5."""
  # Compare the SM number numerically: as strings, "sm_100"/"sm_120" sort below "sm_89".
  # Arch strings not of the form sm_<digits>... keep the original string comparison.
  m = re.match(r"sm_(\d+)", self.arch)
  newer = int(m.group(1)) >= 89 if m else self.arch >= "sm_89"
  return src.replace("TARGET", self.arch).replace("VERSION", "7.8" if newer else "7.5").encode()
|
66
|
+
def disassemble(self, lib:bytes):
  # PTX output goes through the shared ptxas + nvdisasm helper for this arch
  cuda_disassemble(lib, self.arch)
|
70
67
|
|
71
68
|
class NVPTXCompiler(PTXCompiler):
  """PTXCompiler whose compile cache is keyed with "nv_ptx"."""
  def __init__(self, arch:str):
    super().__init__(arch, cache_key="nv_ptx")
|
tinygrad/runtime/support/elf.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
|
-
import struct
|
1
|
+
import struct
|
2
2
|
from dataclasses import dataclass
|
3
3
|
from tinygrad.helpers import getbits, i2u
|
4
|
+
import tinygrad.runtime.autogen.libc as libc
|
4
5
|
|
5
6
|
@dataclass(frozen=True)
class ElfSection:
  """An immutable ELF section record: its name, its raw Elf64_Shdr header, and its content bytes."""
  name: str
  header: libc.Elf64_Shdr
  content: bytes
|