tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,138 @@
1
+ import functools, importlib, re, urllib
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass
4
+ from tinygrad.helpers import getbits, round_up, fetch
5
+ from tinygrad.runtime.autogen import pci
6
+ from tinygrad.runtime.support.usb import ASM24Controller
7
+
8
@dataclass
class AMDReg:
  """One ASIC register: its offset inside a base segment, its named bit fields and the per-instance segment base tables."""
  name: str
  offset: int
  segment: int  # index into each instance's tuple of segment bases
  fields: dict[str, tuple[int, int]]  # field name -> (lowest bit, highest bit), both inclusive
  bases: dict[int, tuple[int, ...]]  # instance -> segment base addresses

  def __post_init__(self):
    # absolute address for every instance: that instance's base for `segment` plus the register offset
    self.addr: dict[int, int] = {}
    for inst, inst_bases in self.bases.items():
      self.addr[inst] = inst_bases[self.segment] + self.offset

  def encode(self, **kwargs) -> int:
    """Pack keyword field values into one register value. Values are shifted into place but not masked to the field width."""
    shifted = [fval << self.fields[fname][0] for fname, fval in kwargs.items()]
    return functools.reduce(int.__or__, shifted, 0)

  def decode(self, val: int) -> dict:
    """Split a raw register value into {field name: field value}."""
    return {fname: getbits(val, lo, hi) for fname, (lo, hi) in self.fields.items()}

  def fields_mask(self, *names) -> int:
    """Return a bit mask covering every bit of the named fields."""
    def _field_mask(fname):
      lo, hi = self.fields[fname][0], self.fields[fname][1]
      return ((1 << (hi - lo + 1)) - 1) << lo
    return functools.reduce(int.__or__, map(_field_mask, names), 0)
18
+
19
class _RegisterNotFound(AttributeError, KeyError):
  """Unknown register on an AMDIP. Also a KeyError so handlers written for the old KeyError keep working."""

@dataclass
class AMDIP:
  """An AMD IP block (name, version, per-instance bases) whose registers are reachable as attributes, e.g. `ip.regFOO`."""
  name:str; version:tuple[int, ...]; bases:dict[int, tuple[int, ...]] # noqa: E702
  # normalize to the first candidate version from fixup_ip_version (applies any version override)
  def __post_init__(self): self.version = fixup_ip_version(self.name, self.version)[0]

  # register table: built by import_asic_regs (which downloads the headers) on first access, then cached on the instance
  @functools.cached_property
  def regs(self): return import_asic_regs(self.name, self.version, cls=functools.partial(AMDReg, bases=self.bases))

  def __getattr__(self, name:str):
    """
    Resolve unknown attributes as register names.

    Only called when normal lookup fails. Raises AttributeError (a KeyError too, see _RegisterNotFound) when no register matches,
    so hasattr()/getattr(obj, name, default) behave normally.
    """
    # Private/dunder names, dataclass fields and `regs` itself must never go through the register table: copy/pickle probe
    # dunders on instances whose fields are not set yet, and resolving them via self.regs would recurse or start a download.
    if name.startswith('_') or name == 'regs' or name in type(self).__dataclass_fields__: raise AttributeError(name)
    if name in self.regs: return self.regs[name]

    # NOTE: gfx10 gc registers always start with mm, no reg prefix
    if name.startswith('reg') and (mm_name:=f"mm{name[3:]}") in self.regs: return self.regs[mm_name]
    raise _RegisterNotFound(name)
32
+
33
def fixup_ip_version(ip:str, version:tuple[int, ...]) -> list[tuple[int, ...]]:
  """
  Return the candidate versions to try for an IP, most specific first.

  Some IP versions are first remapped to another version (override table below). The result is
  [full version, major.minor, major.minor.0, major.0.0] of the remapped version.
  """
  # override versions: the first matching prefix selects the replacement version
  overrides: dict[tuple[int, ...], tuple[int, ...]] = {}
  if ip in ('nbio', 'nbif'): overrides = {(3,3): (2,3,0)}
  elif ip in ('mp', 'smu'): overrides = {(14,0,3): (14,0,2)}
  elif ip == 'gc': overrides = {(9,5,0): (9,4,3)}
  version = next((repl for prefix, repl in overrides.items() if version[:len(prefix)] == prefix), version)

  major_minor = version[:2]
  return [version, major_minor, major_minor + (0,), version[:1] + (0, 0)]
45
+
46
def header_download(file, name=None, subdir="defines") -> str:
  """Fetch `file` (a path relative to drivers/gpu/drm/amd at a pinned linux-next commit) via `fetch` and return its text."""
  commit = "cf6d949a409e09539477d32dbe7c954e4852e744"
  amd_root = f"https://gitlab.com/linux-kernel/linux-next/-/raw/{commit}/drivers/gpu/drm/amd"
  downloaded = fetch(f"{amd_root}/{file}", name=name, subdir=subdir)
  return downloaded.read_text()
49
+
50
def import_header(path:str):
  """Download a header and return every `NAME = <int literal>` assignment found in it (e.g. enum values) as {NAME: value}."""
  # Strip comments first so commented-out assignments are ignored. `//` comments end at the newline: under re.S a bare `//.*`
  # would match through to the end of the file and silently drop everything after the first line comment.
  text = re.sub(r'//[^\n]*|/\*.*?\*/', '', header_download(path, subdir="defines"), flags=re.S)
  return {k:int(v,0) for k,v in re.findall(r'\b([A-Za-z_]\w*)\s*=\s*(0x[0-9A-Fa-f]+|\d+)', text)}
53
+
54
def import_module(name:str, version:tuple[int, ...], version_prefix:str=""):
  """
  Import tinygrad.runtime.autogen.am.<name>_<version_prefix><v0_v1_...> for the first candidate version (from
  fixup_ip_version) that imports successfully. Raises ImportError if none does.
  """
  module_names = [f"tinygrad.runtime.autogen.am.{name}_{version_prefix}{'_'.join(map(str, ver))}" for ver in fixup_ip_version(name, version)]
  for module_name in module_names:
    try: return importlib.import_module(module_name)
    except ImportError: continue
  raise ImportError(f"Failed to load autogen module for {name.upper()} {'.'.join(map(str, version))}")
59
+
60
def import_soc(ip):
  """Return a class named SOC whose attributes are the values parsed from the *_enum.h header chosen by ip[0] (KeyError if unknown)."""
  header_stem = {9: 'vega10', 10: 'navi10', 11: 'soc21', 12: 'soc24'}[ip[0]]
  enum_values = import_header(f"include/{header_stem}_enum.h")
  return type("SOC", (object,), enum_values)
61
+
62
def import_asic_regs(prefix:str, version:tuple[int, ...], cls=AMDReg) -> dict[str, AMDReg]:
  """
  Download the <prefix>_<ver>_offset.h and _sh_mask.h register headers of an IP and build register objects from them.

  Candidate versions from fixup_ip_version are tried in order. An HTTP 404 moves on to the next candidate; any other HTTP error
  is re-raised. The requested `version`'s file name is passed to `fetch` as `name` for every candidate.

  Returns {register name: cls(name=..., offset=..., segment=..., fields=...)} for every reg*/mm* define that has a matching
  *_BASE_IDX define. Raises ImportError if no candidate version has headers.
  """
  import urllib.error  # `import urllib` alone does not guarantee the urllib.error submodule is loaded

  # split at the first uppercase char: "regGCVM_CNTL" -> ("reg", "GCVM_CNTL")
  def _split_name(name): return name[:(pos:=next((i for i,c in enumerate(name) if c.isupper()), len(name)))], name[pos:]
  # every `#define NAME <int literal>` line -> {NAME: value}
  def _extract_regs(txt):
    return {m.group(1): int(m.group(2), 0) for line in txt.splitlines() if (m:=re.match(r'#define\s+(\S+)\s+(0x[\da-fA-F]+|\d+)', line))}
  def _download_file(ver, suff) -> str:
    dir_prefix = {"osssys": "oss"}.get(prefix, prefix)
    fetch_name, file_name = f"{prefix}_{'_'.join(map(str, ver))}_{suff}.h", f"{prefix}_{'_'.join(map(str, version))}_{suff}.h"
    return header_download(f"include/asic_reg/{dir_prefix}/{fetch_name}", name=file_name, subdir="asic_regs")

  for ver in fixup_ip_version(prefix, version):
    try: offs, sh_masks = _extract_regs(_download_file(ver, "offset")), _extract_regs(_download_file(ver, "sh_mask"))
    except urllib.error.HTTPError as e:
      if e.code == 404: continue
      raise

    offsets = {k:v for k,v in offs.items() if _split_name(k)[0] in {'reg', 'mm'} and not k.endswith('_BASE_IDX')}
    bases = {k[:-len('_BASE_IDX')]:v for k,v in offs.items() if _split_name(k)[0] in {'reg', 'mm'} and k.endswith('_BASE_IDX')}

    # mask defines look like <REG>__<FIELD>_MASK; each mask becomes an inclusive (lowest set bit, highest set bit) range
    fields: defaultdict[str, dict[str, tuple[int, int]]] = defaultdict(dict)
    for field_name, field_mask in sh_masks.items():
      if not field_name.endswith('_MASK') or '__' not in (base_name:=field_name[:-len('_MASK')]): continue
      # split only at the first "__": a field name that itself contains "__" must not make the unpacking raise ValueError
      reg_name, reg_field_name = base_name.split('__', 1)
      fields[reg_name][reg_field_name.lower()] = ((field_mask & -field_mask).bit_length()-1, field_mask.bit_length()-1)

    # NOTE: Some registers like regGFX_IMU_FUSESTRAP in gc_11_0_0 are missing base idx, just skip them
    return {reg:cls(name=reg, offset=off, segment=bases[reg], fields=fields[_split_name(reg)[1]]) for reg,off in offsets.items() if reg in bases}
  raise ImportError(f"Failed to load ASIC registers for {prefix.upper()} {'.'.join(map(str, version))}")
89
+
90
def setup_pci_bars(usb:ASM24Controller, gpu_bus:int, mem_base:int, pref_mem_base:int) -> dict[int, tuple[int, int]]:
  """
  Configure the bridges on buses 0..gpu_bus-1 and assign addresses to the memory BARs of the device at gpu_bus:00.0.

  usb: controller used for every config-space access (`pcie_cfg_req`).
  gpu_bus: bus number of the target device.
  mem_base / pref_mem_base: first address handed out to non-prefetchable / prefetchable BARs; each BAR advances the
  corresponding address by its size rounded up to 2 MiB.

  Returns {BAR index: (assigned address, size in bytes)} for every memory BAR. I/O BARs are not assigned.
  """
  for bus in range(gpu_bus):
    # All 3 values must be written at the same time.
    # primary bus 0 (bits 0-7), secondary bus+1 (bits 8-15), subordinate gpu_bus (bits 16-23)
    buses = (0 << 0) | ((bus+1) << 8) | ((gpu_bus) << 16)
    usb.pcie_cfg_req(pci.PCI_PRIMARY_BUS, bus=bus, dev=0, fn=0, value=buses, size=4)

    # memory windows: bases from mem_base / pref_mem_base, limit registers set to all ones
    usb.pcie_cfg_req(pci.PCI_MEMORY_BASE, bus=bus, dev=0, fn=0, value=(mem_base>>16) & 0xffff, size=2)
    usb.pcie_cfg_req(pci.PCI_MEMORY_LIMIT, bus=bus, dev=0, fn=0, value=0xffff, size=2)
    usb.pcie_cfg_req(pci.PCI_PREF_MEMORY_BASE, bus=bus, dev=0, fn=0, value=(pref_mem_base>>16) & 0xffff, size=2)
    usb.pcie_cfg_req(pci.PCI_PREF_MEMORY_LIMIT, bus=bus, dev=0, fn=0, value=0xffff, size=2)
    usb.pcie_cfg_req(pci.PCI_PREF_BASE_UPPER32, bus=bus, dev=0, fn=0, value=pref_mem_base >> 32, size=4)
    usb.pcie_cfg_req(pci.PCI_PREF_LIMIT_UPPER32, bus=bus, dev=0, fn=0, value=0xffffffff, size=4)

    usb.pcie_cfg_req(pci.PCI_COMMAND, bus=bus, dev=0, fn=0, value=pci.PCI_COMMAND_IO | pci.PCI_COMMAND_MEMORY | pci.PCI_COMMAND_MASTER, size=1)

  # resize bar 0: write the highest size bit advertised by the Resizable BAR capability (presumably the largest supported size)
  # into its first control register. The walk stops on a pointer below 0x100 or one already visited, so a malformed capability
  # list or an unreadable config space (all-ones reads) cannot loop forever.
  cap_ptr, visited = 0x100, set()
  while cap_ptr >= 0x100 and cap_ptr not in visited:
    visited.add(cap_ptr)
    if pci.PCI_EXT_CAP_ID(hdr:=usb.pcie_cfg_req(cap_ptr, bus=gpu_bus, dev=0, fn=0, size=4)) == pci.PCI_EXT_CAP_ID_REBAR:
      cap = usb.pcie_cfg_req(cap_ptr + 0x04, bus=gpu_bus, dev=0, fn=0, size=4)
      new_ctrl = (usb.pcie_cfg_req(cap_ptr + 0x08, bus=gpu_bus, dev=0, fn=0, size=4) & ~0x1F00) | ((int(cap >> 4).bit_length() - 1) << 8)
      usb.pcie_cfg_req(cap_ptr + 0x08, bus=gpu_bus, dev=0, fn=0, value=new_ctrl, size=4)

    cap_ptr = pci.PCI_EXT_CAP_NEXT(hdr)

  # walk the 6 BAR dwords (24 bytes). mem_space_addr[0] is the next free non-prefetchable address, [1] the next prefetchable one.
  mem_space_addr, bar_off, bars = [mem_base, pref_mem_base], 0, {}
  while bar_off < 24:
    cfg = usb.pcie_cfg_req(pci.PCI_BASE_ADDRESS_0 + bar_off, bus=gpu_bus, dev=0, fn=0, size=4)
    # The prefetch and 64-bit type bits exist only in memory BARs. In an I/O BAR those bits are address bits, so decoding them
    # there could mistake an I/O BAR for a 64-bit one and skip the following BAR.
    is_mem = (cfg & pci.PCI_BASE_ADDRESS_SPACE) == pci.PCI_BASE_ADDRESS_SPACE_MEMORY
    bar_mem, bar_64 = bool(cfg & pci.PCI_BASE_ADDRESS_MEM_PREFETCH), is_mem and bool(cfg & pci.PCI_BASE_ADDRESS_MEM_TYPE_64)

    if is_mem:
      # size the BAR: write all ones, read back the address mask (low 4 flag bits dropped)
      usb.pcie_cfg_req(pci.PCI_BASE_ADDRESS_0 + bar_off, bus=gpu_bus, dev=0, fn=0, value=0xffffffff, size=4)
      lo = (usb.pcie_cfg_req(pci.PCI_BASE_ADDRESS_0 + bar_off, bus=gpu_bus, dev=0, fn=0, size=4) & 0xfffffff0)

      if bar_64: usb.pcie_cfg_req(pci.PCI_BASE_ADDRESS_0 + bar_off + 4, bus=gpu_bus, dev=0, fn=0, value=0xffffffff, size=4)
      hi = (usb.pcie_cfg_req(pci.PCI_BASE_ADDRESS_0 + bar_off + 4, bus=gpu_bus, dev=0, fn=0, size=4) if bar_64 else 0)

      bar_size = ((~(((hi << 32) | lo) & ~0xf)) + 1) & (0xffffffffffffffff if bar_64 else 0xffffffff)

      # place the BAR at the next free address of its kind
      usb.pcie_cfg_req(pci.PCI_BASE_ADDRESS_0 + bar_off, bus=gpu_bus, dev=0, fn=0, value=mem_space_addr[bar_mem] & 0xffffffff, size=4)
      if bar_64: usb.pcie_cfg_req(pci.PCI_BASE_ADDRESS_0 + bar_off + 4, bus=gpu_bus, dev=0, fn=0, value=mem_space_addr[bar_mem] >> 32, size=4)

      bars[bar_off // 4] = (mem_space_addr[bar_mem], bar_size)
      mem_space_addr[bar_mem] += round_up(bar_size, 2 << 20)

    # a 64-bit BAR occupies two dwords
    bar_off += 8 if bar_64 else 4

  usb.pcie_cfg_req(pci.PCI_COMMAND, bus=gpu_bus, dev=0, fn=0, value=pci.PCI_COMMAND_IO | pci.PCI_COMMAND_MEMORY | pci.PCI_COMMAND_MASTER, size=1)
  return bars
@@ -1,6 +1,20 @@
1
1
  import ctypes, subprocess
2
2
  import tinygrad.runtime.autogen.comgr as comgr
3
+ assert comgr.AMD_COMGR_LANGUAGE_HIP == 4
4
+ try:
5
+ comgr.amd_comgr_get_version(ctypes.byref(major:=ctypes.c_uint64()), ctypes.byref(minor:=ctypes.c_uint64()))
6
+ if major.value >= 3:
7
+ # in comgr 3 the values of enums in headers were changed: https://github.com/ROCm/llvm-project/issues/272
8
+ import tinygrad.runtime.autogen.comgr_3 as comgr # type: ignore[no-redef]
9
+ assert comgr.AMD_COMGR_LANGUAGE_HIP == 3
10
+ except AttributeError: pass # ignore if ROCm isn't installed
3
11
  from tinygrad.device import Compiler, CompileError
12
+ from tinygrad.runtime.ops_llvm import LLVMCompiler
13
+ from tinygrad.helpers import OSX, to_char_p_p
14
+
15
def amdgpu_disassemble(lib:bytes):
  """Disassemble an AMDGPU binary with llvm-objdump (ROCm's copy unless on macOS) and print it without s_code_end lines."""
  objdump = "llvm-objdump" if OSX else "/opt/rocm/llvm/bin/llvm-objdump"
  listing = subprocess.check_output([objdump, '-d', '-'], input=lib).decode('utf-8')
  print('\n'.join(line for line in listing.split("\n") if 's_code_end' not in line))
4
18
 
5
19
  def check(status):
6
20
  if status != 0:
@@ -14,6 +28,12 @@ def _get_comgr_data(data_set, data_type):
14
28
  check(comgr.amd_comgr_release_data(data_exec))
15
29
  return bytes(dat)
16
30
 
31
# amd_comgr_action_info_set_options was deprecated
def set_options(action_info, options:bytes):
  """
  Set the option list of a comgr action from a whitespace-separated byte string and return the comgr status code.

  Uses bytes.split() with no separator, so b"" yields no options and repeated spaces yield no empty entries. By contrast,
  split(b' ') turns b"" into a single empty-string option.
  """
  # TODO: this type should be correct in the autogen stub
  comgr.amd_comgr_action_info_set_option_list.argtypes = [comgr.amd_comgr_action_info_t, ctypes.POINTER(ctypes.POINTER(ctypes.c_char)), comgr.size_t]
  options_list = options.split()
  return comgr.amd_comgr_action_info_set_option_list(action_info, to_char_p_p(options_list), len(options_list))
36
+
17
37
  # AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
18
38
  def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
19
39
  check(comgr.amd_comgr_create_action_info(ctypes.byref(action_info := comgr.amd_comgr_action_info_t())))
@@ -40,15 +60,15 @@ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
40
60
  check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
41
61
  check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
42
62
  # -include hiprtc_runtime.h was removed
43
- check(comgr.amd_comgr_action_info_set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode())) # noqa: E501
63
+ check(set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes -Xclang -aux-triple -Xclang x86_64-unknown-linux-gnu".encode())) # noqa: E501
44
64
  status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
45
65
  if status != 0:
46
66
  print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
47
67
  raise RuntimeError("compile failed")
48
- check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
68
+ check(set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
49
69
  check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
50
70
 
51
- check(comgr.amd_comgr_action_info_set_options(action_info, b""))
71
+ check(set_options(action_info, b""))
52
72
  check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action_info, data_set_reloc, data_set_exec))
53
73
  ret = _get_comgr_data(data_set_exec, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)
54
74
  check(comgr.amd_comgr_release_data(data_src))
@@ -56,13 +76,25 @@ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
56
76
  check(comgr.amd_comgr_destroy_action_info(action_info))
57
77
  return ret
58
78
 
59
- class AMDCompiler(Compiler):
79
+ class HIPCompiler(Compiler):
60
80
  def __init__(self, arch:str):
61
81
  self.arch = arch
62
82
  super().__init__(f"compile_hip_{self.arch}")
63
83
  def compile(self, src:str) -> bytes:
64
- try: return compile_hip(src, self.arch)
84
+ try: return compile_hip(src, self.arch, src.split('\n', 1)[0].strip() == '.text')
65
85
  except RuntimeError as e: raise CompileError(e) from e
66
- def disassemble(self, lib:bytes):
67
- asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
68
- print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
86
+ def disassemble(self, lib:bytes): amdgpu_disassemble(lib)
87
+
88
class AMDLLVMCompiler(LLVMCompiler):
  """Compile LLVM IR for an AMDGPU `arch` (e.g. "gfx1100") via LLVMCompiler, passing "+cumode" to its constructor."""
  jit = False
  target_arch = "AMDGPU"
  def __init__(self, arch: str):
    self.arch = arch
    super().__init__(self.arch, "+cumode")
  # pickle as a fresh compiler built from the arch string only
  def __reduce__(self): return (AMDLLVMCompiler, (self.arch,))
  def compile(self, src:str) -> bytes:
    """Compile `src`, converting RuntimeError from LLVM into CompileError."""
    try: return super().compile(src)
    except RuntimeError as e:
      # an undefined @llvm.amdgcn.* value gets a version hint, on its own line so it doesn't run into LLVM's message
      if "undefined value '@llvm.amdgcn." in str(e): raise CompileError(str(e) + "\nAMD with LLVM backend requires LLVM >= 18") from e
      raise CompileError(e) from e
  def disassemble(self, lib:bytes): amdgpu_disassemble(lib)
@@ -4,7 +4,7 @@ from tinygrad.helpers import to_char_p_p, colored, init_c_var, getenv
4
4
  import tinygrad.runtime.autogen.nvrtc as nvrtc
5
5
  from tinygrad.device import Compiler, CompileError
6
6
 
7
- PTX = getenv("PTX") # this shouldn't be here, in fact, it shouldn't exist
7
+ PTX, CUDA_PATH = getenv("PTX"), getenv("CUDA_PATH", "") # PTX shouldn't be here, in fact, it shouldn't exist
8
8
 
9
9
  def _get_bytes(arg, get_str, get_sz, check) -> bytes:
10
10
  sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))
@@ -30,17 +30,18 @@ def pretty_ptx(s):
30
30
  s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
31
31
  return s
32
32
 
33
- def cuda_disassemble(lib, arch):
33
+ def cuda_disassemble(lib:bytes, arch:str):
34
34
  try:
35
35
  fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
36
- with open(fn + ".ptx", "wb") as f: f.write(lib)
37
- subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
36
+ with open(fn, "wb") as f: f.write(lib)
37
+ subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn], check=False, stderr=subprocess.DEVNULL) # optional ptx -> sass step for CUDA=1
38
38
  print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
39
39
  except Exception as e: print("Failed to generate SASS", str(e), "Make sure your PATH contains ptxas/nvdisasm binary of compatible version.")
40
40
 
41
41
  class CUDACompiler(Compiler):
42
42
  def __init__(self, arch:str, cache_key:str="cuda"):
43
- self.arch, self.compile_options = arch, [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
43
+ self.arch, self.compile_options = arch, [f'--gpu-architecture={arch}']
44
+ self.compile_options += [f"-I{CUDA_PATH}/include"] if CUDA_PATH else ["-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include"]
44
45
  nvrtc_check(nvrtc.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
45
46
  if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
46
47
  super().__init__(f"compile_{cache_key}_{self.arch}")
@@ -51,12 +52,7 @@ class CUDACompiler(Compiler):
51
52
  nvrtc_check(nvrtc.nvrtcDestroyProgram(ctypes.byref(prog)))
52
53
  return data
53
54
  def compile(self, src:str) -> bytes: return self._compile_program(src, nvrtc.nvrtcGetPTX, nvrtc.nvrtcGetPTXSize)
54
- def disassemble(self, lib:bytes):
55
- try:
56
- fn = (pathlib.Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
57
- with open(fn + ".cubin", "wb") as f: f.write(lib)
58
- print(subprocess.check_output(["nvdisasm", fn+".cubin"]).decode('utf-8'))
59
- except Exception as e: print("Failed to disasm cubin:", str(e), "Make sure your PATH contains nvdisasm binary of compatible version.")
55
+ def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
60
56
 
61
57
  class NVCompiler(CUDACompiler):
62
58
  def __init__(self, arch:str): super().__init__(arch, cache_key="nv")
@@ -67,6 +63,7 @@ class PTXCompiler(Compiler):
67
63
  self.arch = arch
68
64
  super().__init__(f"compile_{cache_key}_{self.arch}")
69
65
  def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", "7.8" if self.arch >= "sm_89" else "7.5").encode()
66
+ def disassemble(self, lib:bytes): cuda_disassemble(lib, self.arch)
70
67
 
71
68
  class NVPTXCompiler(PTXCompiler):
72
69
  def __init__(self, arch:str): super().__init__(arch, cache_key="nv_ptx")
@@ -1,6 +1,7 @@
1
- import struct, tinygrad.runtime.autogen.libc as libc
1
+ import struct
2
2
  from dataclasses import dataclass
3
3
  from tinygrad.helpers import getbits, i2u
4
+ import tinygrad.runtime.autogen.libc as libc
4
5
 
5
6
  @dataclass(frozen=True)
6
7
  class ElfSection: name:str; header:libc.Elf64_Shdr; content:bytes # noqa: E702