tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -1,44 +1,35 @@
1
1
  from __future__ import annotations
2
- import ctypes, collections, time, dataclasses, pathlib, fcntl, os
3
- from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
4
- from tinygrad.runtime.autogen.am import am, mp_11_0
5
- from tinygrad.runtime.support.allocator import TLSFAllocator
6
- from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
2
+ import ctypes, collections, dataclasses, functools, os, hashlib
3
+ from tinygrad.helpers import mv_address, getenv, DEBUG, fetch
4
+ from tinygrad.runtime.autogen.am import am
5
+ from tinygrad.runtime.support.hcq import MMIOInterface
6
+ from tinygrad.runtime.support.amd import AMDReg, import_module, import_asic_regs
7
+ from tinygrad.runtime.support.memory import TLSFAllocator, MemoryManager
8
+ from tinygrad.runtime.support.system import System, PCIDevImplBase
9
+ from tinygrad.runtime.support.am.ip import AM_SOC, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
7
10
 
8
11
  AM_DEBUG = getenv("AM_DEBUG", 0)
9
12
 
10
- @dataclasses.dataclass(frozen=True)
11
- class AMRegister:
12
- adev:AMDev; reg_off:int; reg_fields:dict[str, tuple[int, int]] # noqa: E702
13
+ @dataclasses.dataclass
14
+ class AMRegister(AMDReg):
15
+ adev:AMDev
13
16
 
14
- def _parse_kwargs(self, **kwargs):
15
- mask, values = 0xffffffff, 0
16
- for k, v in kwargs.items():
17
- if k not in self.reg_fields: raise ValueError(f"Unknown register field: {k}. {self.reg_fields.keys()}")
18
- m, s = self.reg_fields[k]
19
- if v & (m>>s) != v: raise ValueError(f"Value {v} for {k} is out of range {m=} {s=}")
20
- mask &= ~m
21
- values |= v << s
22
- return mask, values
17
+ def read(self, inst=0): return self.adev.rreg(self.addr[inst])
18
+ def read_bitfields(self, inst=0) -> dict[str, int]: return self.decode(self.read(inst=inst))
23
19
 
24
- def build(self, **kwargs) -> int: return self._parse_kwargs(**kwargs)[1]
20
+ def write(self, _am_val:int=0, inst=0, **kwargs): self.adev.wreg(self.addr[inst], _am_val | self.encode(**kwargs))
25
21
 
26
- def update(self, **kwargs): self.write(value=self.read(), **kwargs)
27
-
28
- def write(self, value=0, **kwargs):
29
- mask, values = self._parse_kwargs(**kwargs)
30
- self.adev.wreg(self.reg_off, (value & mask) | values)
31
-
32
- def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0]
22
+ def update(self, inst=0, **kwargs): self.write(self.read(inst=inst) & ~self.fields_mask(*kwargs.keys()), inst=inst, **kwargs)
33
23
 
34
24
  class AMFirmware:
35
25
  def __init__(self, adev):
36
- def fmt_ver(hwip): return f"{adev.ip_versions[hwip]//10000}_{(adev.ip_versions[hwip]//100)%100}_{adev.ip_versions[hwip]%100}"
26
+ self.adev = adev
27
+ def fmt_ver(hwip): return '_'.join(map(str, adev.ip_ver[hwip]))
37
28
 
38
29
  # Load SOS firmware
39
30
  self.sos_fw = {}
40
31
 
41
- blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", am.struct_psp_firmware_header_v2_0)
32
+ blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", versioned_header='struct_psp_firmware_header')
42
33
  fw_bin = sos_hdr.psp_fw_bin
43
34
 
44
35
  for fw_i in range(sos_hdr.psp_fw_bin_count):
@@ -48,205 +39,88 @@ class AMFirmware:
48
39
 
49
40
  # Load other fw
50
41
  self.ucode_start: dict[str, int] = {}
51
- self.descs: list[tuple[int, memoryview]] = []
42
+ self.descs: list[tuple[list[int], memoryview]] = []
52
43
 
53
44
  blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", am.struct_smc_firmware_header_v1_0)
54
- self.smu_psp_desc = self.desc(am.GFX_FW_TYPE_SMU, blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes)
45
+ self.smu_psp_desc = self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes, am.GFX_FW_TYPE_SMU)
55
46
 
56
47
  # SDMA firmware
57
- blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", am.struct_sdma_firmware_header_v2_0)
58
- self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH0, blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes)]
59
- self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH1, blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes)]
48
+ blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", versioned_header='struct_sdma_firmware_header')
49
+ if hdr.header.header_version_major < 3:
50
+ self.descs += [self.desc(blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH1)]
51
+ self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)]
52
+ else: self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes, am.GFX_FW_TYPE_SDMA_UCODE_TH0)]
60
53
 
61
54
  # PFP, ME, MEC firmware
62
- for (fw_name, fw_cnt) in [('PFP', 2), ('ME', 2), ('MEC', 4)]:
55
+ for (fw_name, fw_cnt) in ([('PFP', 1), ('ME', 1)] if self.adev.ip_ver[am.GC_HWIP] >= (12,0,0) else []) + [('MEC', 1)]:
63
56
  blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0)
64
57
 
65
58
  # Code part
66
- self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'), blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes)]
59
+ self.descs += [self.desc(blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes, getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'))]
67
60
 
68
61
  # Stack
69
- fw_types = [getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}_P{fwnun}_STACK') for fwnun in range(fw_cnt)]
70
- self.descs += [self.desc(typ, blob, hdr.data_offset_bytes, hdr.data_size_bytes) for typ in fw_types]
62
+ stack_fws = [getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}_P{fwnum}_STACK') for fwnum in range(fw_cnt)]
63
+ self.descs += [self.desc(blob, hdr.data_offset_bytes, hdr.data_size_bytes, *stack_fws)]
71
64
  self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32)
72
65
 
73
66
  # IMU firmware
74
67
  blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_imu.bin", am.struct_imu_firmware_header_v1_0)
75
68
  imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes
76
- self.descs += [self.desc(am.GFX_FW_TYPE_IMU_I, blob, imu_i_off, imu_i_sz), self.desc(am.GFX_FW_TYPE_IMU_D, blob, imu_i_off + imu_i_sz, imu_d_sz)]
69
+ self.descs += [self.desc(blob, imu_i_off, imu_i_sz, am.GFX_FW_TYPE_IMU_I), self.desc(blob, imu_i_off + imu_i_sz, imu_d_sz, am.GFX_FW_TYPE_IMU_D)]
77
70
 
78
71
  # RLC firmware
79
- blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
72
+ blob, hdr0, _hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
80
73
  am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3)
81
74
 
82
- for mem in ['GPM', 'SRM']:
83
- off, sz = getattr(hdr1, f'save_restore_list_{mem.lower()}_offset_bytes'), getattr(hdr1, f'save_restore_list_{mem.lower()}_size_bytes')
84
- self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_RESTORE_LIST_{mem}_MEM'), blob, off, sz)]
85
-
86
75
  for mem,fmem in [('IRAM', 'iram'), ('DRAM_BOOT', 'dram')]:
87
76
  off, sz = getattr(hdr2, f'rlc_{fmem}_ucode_offset_bytes'), getattr(hdr2, f'rlc_{fmem}_ucode_size_bytes')
88
- self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}'), blob, off, sz)]
77
+ self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_{mem}'))]
89
78
 
90
- for mem in ['P', 'V']:
91
- off, sz = getattr(hdr3, f'rlc{mem.lower()}_ucode_offset_bytes'), getattr(hdr3, f'rlc{mem.lower()}_ucode_size_bytes')
92
- self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}'), blob, off, sz)]
79
+ if hdr0.header.header_version_minor == 3:
80
+ for mem in ['P', 'V']:
81
+ off, sz = getattr(hdr3, f'rlc{mem.lower()}_ucode_offset_bytes'), getattr(hdr3, f'rlc{mem.lower()}_ucode_size_bytes')
82
+ self.descs += [self.desc(blob, off, sz, getattr(am, f'GFX_FW_TYPE_RLC_{mem}'))]
93
83
 
94
- self.descs += [self.desc(am.GFX_FW_TYPE_RLC_G, blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes)]
84
+ self.descs += [self.desc(blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes, am.GFX_FW_TYPE_RLC_G)]
95
85
 
96
- def load_fw(self, fname:str, *headers):
97
- fpath = next(f for loc in ["/lib/firmware/updates/amdgpu/", "/lib/firmware/amdgpu/"] if (f:=pathlib.Path(loc + fname)).exists())
86
+ def load_fw(self, fname:str, *headers, versioned_header:str|None=None):
87
+ fpath = fetch(f"https://gitlab.com/kernel-firmware/linux-firmware/-/raw/45f59212aebd226c7630aff4b58598967c0c8c91/amdgpu/{fname}", subdir="fw")
98
88
  blob = memoryview(bytearray(fpath.read_bytes()))
89
+ if AM_DEBUG >= 1: print(f"am {self.adev.devfmt}: loading firmware {fname}: {hashlib.sha256(blob).hexdigest()}")
90
+ if versioned_header:
91
+ chdr = am.struct_common_firmware_header.from_address(mv_address(blob))
92
+ headers += (getattr(am, versioned_header + f"_v{chdr.header_version_major}_{chdr.header_version_minor}"),)
99
93
  return tuple([blob] + [hdr.from_address(mv_address(blob)) for hdr in headers])
100
94
 
101
- def desc(self, typ:int, blob:memoryview, offset:int, size:int) -> tuple[int, memoryview]: return (typ, blob[offset:offset+size])
102
-
103
- @dataclasses.dataclass(frozen=True)
104
- class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
95
+ def desc(self, blob:memoryview, offset:int, size:int, *types:int) -> tuple[list[int], memoryview]: return (list(types), blob[offset:offset+size])
105
96
 
106
97
  class AMPageTableEntry:
107
- def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.entries, self.lv = adev, paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv
98
+ def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.lv, self.entries = adev, paddr, lv, adev.vram.view(paddr, 0x1000, fmt='Q')
108
99
 
109
100
  def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
110
101
  assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"
102
+ self.entries[entry_id] = self.adev.gmc.get_pte_flags(self.lv, table, frag, uncached, system, snooped, valid) | (paddr & 0x0000FFFFFFFFF000)
111
103
 
112
- f = (am.AMDGPU_PTE_VALID if valid else 0) | ((am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE) if not table else 0) \
113
- | am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if not table and self.lv != am.AMDGPU_VM_PTB else 0) \
114
- | ((am.AMDGPU_PTE_SYSTEM) if system else 0) | ((am.AMDGPU_PTE_SNOOPED) if snooped else 0) \
115
- | (am.AMDGPU_PTE_MTYPE_NV10(0, am.MTYPE_UC) if uncached else 0)
116
- self.entries[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f
117
-
118
- class AMPageTableTraverseContext:
119
- def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False):
120
- self.adev, self.vaddr, self.create_pts, self.free_pts = adev, vaddr - adev.gmc.vm_base, create_pts, free_pts
121
- self.pt_stack:list[tuple[AMPageTableEntry, int, int]] = [(pt, self._pt_pte_idx(pt, vaddr), self._pt_pte_size(pt))]
122
-
123
- def _pt_pte_size(self, pt): return (1 << ((9 * (3-pt.lv)) + 12))
124
- def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % 512
125
-
126
- def level_down(self):
127
- pt, pte_idx, _ = self.pt_stack[-1]
128
- if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0:
129
- assert self.create_pts, "Not allowed to create new page table"
130
- pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True), table=True, valid=True)
131
- entry = pt.entries[pte_idx]
132
-
133
- assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
134
- child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)
135
-
136
- self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
137
- return self.pt_stack[-1]
138
-
139
- def _try_free_pt(self) -> bool:
140
- pt, _, _ = self.pt_stack[-1]
141
- if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
142
- self.adev.mm.pfree(pt.paddr)
143
- parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
144
- parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
145
- return True
146
- return False
147
-
148
- def level_up(self):
149
- while self._try_free_pt() or self.pt_stack[-1][1] == 512:
150
- _, pt_cnt, _ = self.pt_stack.pop()
151
- if pt_cnt == 512: self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])
152
-
153
- def next(self, size:int, off=0):
154
- while size > 0:
155
- pt, pte_idx, pte_covers = self.pt_stack[-1]
156
- if self.create_pts:
157
- while pte_covers > size: pt, pte_idx, pte_covers = self.level_down()
158
- else:
159
- while pt.lv!=am.AMDGPU_VM_PTB and (pt.entries[pte_idx] & am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down()
160
-
161
- entries = min(size // pte_covers, 512 - pte_idx)
162
- assert entries > 0, "Invalid entries"
163
- yield off, pt, pte_idx, entries, pte_covers
164
-
165
- size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers
166
- self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers)
167
- self.level_up()
168
-
169
- class AMMemoryManager:
170
- va_allocator = TLSFAllocator(512 * (1 << 30), base=0x7F0000000000) # global for all devices.
171
-
172
- def __init__(self, adev:AMDev, vram_size:int):
173
- self.adev, self.vram_size = adev, vram_size
174
- self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device
175
- self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
176
- self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)
177
-
178
- def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
179
- if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")
180
-
181
- assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
182
-
183
- ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True)
184
- for paddr, psize in paddrs:
185
- for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
186
- for pte_off in range(pte_cnt):
187
- assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}"
188
- pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers,
189
- uncached=uncached, system=system, snooped=snooped, frag=0 if pte_covers == 0x1000 else 0x9, valid=True)
190
-
191
- # Invalidate TLB after mappings.
192
- self.adev.gmc.flush_tlb(ip='GC', vmid=0)
193
- self.adev.gmc.flush_tlb(ip='MM', vmid=0)
194
- return AMMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)
195
-
196
- def unmap_range(self, vaddr:int, size:int):
197
- if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")
198
-
199
- ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True)
200
- for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
201
- for pte_id in range(pte_idx, pte_idx + pte_cnt):
202
- assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}"
203
- pt.set_entry(pte_id, paddr=0x0, valid=False)
204
-
205
- @staticmethod
206
- def alloc_vaddr(size:int, align=0x1000) -> int: return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
207
-
208
- def valloc(self, size:int, align=0x1000, uncached=False, contigous=False) -> AMMapping:
209
- # Alloc physical memory and map it to the virtual address
210
- va = self.alloc_vaddr(size, align)
211
-
212
- if contigous: paddrs = [(self.palloc(size, zero=True), size)]
213
- else:
214
- paddrs = []
215
- try:
216
- ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
217
- for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False), seg_size) for _ in range(seg_cnt)]
218
- except MemoryError:
219
- for paddr, _ in paddrs: self.pa_allocator.free(paddr)
220
- raise
104
+ def entry(self, entry_id:int) -> int: return self.entries[entry_id]
105
+ def valid(self, entry_id:int) -> bool: return (self.entries[entry_id] & am.AMDGPU_PTE_VALID) != 0
106
+ def address(self, entry_id:int) -> int: return self.entries[entry_id] & 0x0000FFFFFFFFF000
107
+ def is_huge_page(self, entry_id:int) -> bool: return self.lv == am.AMDGPU_VM_PTB or self.adev.gmc.is_pte_huge_page(self.entries[entry_id])
108
+ def supports_huge_page(self, paddr:int): return self.lv >= am.AMDGPU_VM_PDB2
221
109
 
222
- return self.map_range(va, size, paddrs, uncached=uncached)
110
+ class AMMemoryManager(MemoryManager):
111
+ va_allocator = TLSFAllocator(512 * (1 << 30), base=0x200000000000) # global for all devices.
223
112
 
224
- def vfree(self, vm:AMMapping):
225
- self.unmap_range(vm.va_addr, vm.size)
226
- self.va_allocator.free(vm.va_addr)
227
- for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)
228
-
229
- def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
230
- assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated"
231
- paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
232
- if zero: ctypes.memset(self.adev.paddr2cpu(paddr), 0, size)
233
- return paddr
234
-
235
- def pfree(self, paddr:int): self.pa_allocator.free(paddr)
236
-
237
- class AMDev:
238
- def __init__(self, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
239
- self.devfmt = devfmt
240
- self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar
241
-
242
- os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions
113
+ def on_range_mapped(self):
114
+ # Invalidate TLB after mappings.
115
+ self.dev.gmc.flush_tlb(ip='GC', vmid=0)
116
+ self.dev.gmc.flush_tlb(ip='MM', vmid=0)
243
117
 
244
- # Avoid O_CREAT because we don’t want to re-create/replace an existing file (triggers extra perms checks) when opening as non-owner.
245
- if os.path.exists(lock_name:=temp(f"am_{self.devfmt}.lock")): self.lock_fd = os.open(lock_name, os.O_RDWR)
246
- else: self.lock_fd = os.open(lock_name, os.O_RDWR | os.O_CREAT, 0o666)
118
+ class AMDev(PCIDevImplBase):
119
+ Version = 0xA0000006
247
120
 
248
- try: fcntl.flock(self.lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
249
- except OSError: raise RuntimeError(f"Failed to open AM device {self.devfmt}. It's already in use.")
121
+ def __init__(self, devfmt, vram:MMIOInterface, doorbell:MMIOInterface, mmio:MMIOInterface, dma_regions:list[tuple[int, MMIOInterface]]|None=None):
122
+ self.devfmt, self.vram, self.doorbell64, self.mmio, self.dma_regions = devfmt, vram, doorbell, mmio, dma_regions
123
+ self.lock_fd = System.flock_acquire(f"am_{self.devfmt}.lock")
250
124
 
251
125
  self._run_discovery()
252
126
  self._build_regs()
@@ -260,30 +134,19 @@ class AMDev:
260
134
  # To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for
261
135
  # all blocks that are initialized only during the initial AM boot.
262
136
  # To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag.
263
- self.is_booting, self.smi_dev = True, False # During boot only boot memory can be allocated. This flag is to validate this.
264
- self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000002)) and (getenv("AM_RESET", 0) != 1)
137
+ self.is_booting = True
138
+ self.init_sw(smi_dev=False)
265
139
 
266
- # Memory manager & firmware
267
- self.mm = AMMemoryManager(self, self.vram_size)
268
- self.fw = AMFirmware(self)
269
-
270
- # Initialize IP blocks
271
- self.soc21:AM_SOC21 = AM_SOC21(self)
272
- self.gmc:AM_GMC = AM_GMC(self)
273
- self.ih:AM_IH = AM_IH(self)
274
- self.psp:AM_PSP = AM_PSP(self)
275
- self.smu:AM_SMU = AM_SMU(self)
276
- self.gfx:AM_GFX = AM_GFX(self)
277
- self.sdma:AM_SDMA = AM_SDMA(self)
278
-
279
- if self.partial_boot and (self.reg("regGCVM_CONTEXT0_CNTL").read() != 0):
280
- if DEBUG >= 2: print(f"am {self.devfmt}: MEC is active. Issue a full reset.")
140
+ self.partial_boot = (self.reg("regSCRATCH_REG7").read() == AMDev.Version) and (getenv("AM_RESET", 0) != 1)
141
+ if self.partial_boot and (self.reg("regGCVM_CONTEXT0_CNTL").read() != 0 or self.reg(self.gmc.pf_status_reg("GC")).read() != 0):
142
+ if DEBUG >= 2: print(f"am {self.devfmt}: Malformed state. Issuing a full reset.")
281
143
  self.partial_boot = False
282
144
 
145
+ # Init hw for IP blocks where it is needed
283
146
  if not self.partial_boot:
284
147
  if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset()
285
- for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
286
- ip.init()
148
+ for ip in [self.soc, self.gmc, self.ih, self.psp, self.smu]:
149
+ ip.init_hw()
287
150
  if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
288
151
 
289
152
  # Booting done
@@ -291,25 +154,44 @@ class AMDev:
291
154
 
292
155
  # Re-initialize main blocks
293
156
  for ip in [self.gfx, self.sdma]:
294
- ip.init()
157
+ ip.init_hw()
295
158
  if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
296
159
 
297
160
  self.smu.set_clocks(level=-1) # last level, max perf.
298
- self.gfx.set_clockgating_state()
299
- self.reg("regSCRATCH_REG7").write(am_version)
161
+ for ip in [self.soc, self.gfx]: ip.set_clockgating_state()
162
+ self.reg("regSCRATCH_REG7").write(AMDev.Version)
300
163
  if DEBUG >= 2: print(f"am {self.devfmt}: boot done")
301
164
 
165
+ def init_sw(self, smi_dev=False):
166
+ self.smi_dev = smi_dev # During boot only boot memory can be allocated. This flag is to validate this.
167
+
168
+ # Memory manager & firmware
169
+ self.mm = AMMemoryManager(self, self.vram_size, boot_size=(32 << 20), pt_t=AMPageTableEntry, va_shifts=[12, 21, 30, 39], va_bits=48,
170
+ first_lv=am.AMDGPU_VM_PDB2, va_base=AMMemoryManager.va_allocator.base,
171
+ palloc_ranges=[(1 << (i + 12), 0x1000) for i in range(9 * (3 - am.AMDGPU_VM_PDB2), -1, -1)])
172
+ self.fw = AMFirmware(self)
173
+
174
+ # Initialize IP blocks
175
+ self.soc:AM_SOC = AM_SOC(self)
176
+ self.gmc:AM_GMC = AM_GMC(self)
177
+ self.ih:AM_IH = AM_IH(self)
178
+ self.psp:AM_PSP = AM_PSP(self)
179
+ self.smu:AM_SMU = AM_SMU(self)
180
+ self.gfx:AM_GFX = AM_GFX(self)
181
+ self.sdma:AM_SDMA = AM_SDMA(self)
182
+
183
+ # Init sw for all IP blocks
184
+ for ip in [self.soc, self.gmc, self.ih, self.psp, self.smu, self.gfx, self.sdma]: ip.init_sw()
185
+
302
186
  def fini(self):
303
187
  if DEBUG >= 2: print(f"am {self.devfmt}: Finalizing")
304
- for ip in [self.sdma, self.gfx]: ip.fini()
188
+ for ip in [self.sdma, self.gfx]: ip.fini_hw()
305
189
  self.smu.set_clocks(level=0)
306
190
  self.ih.interrupt_handler()
191
+ os.close(self.lock_fd)
307
192
 
308
- def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
309
193
  def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
310
194
 
311
- def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]
312
-
313
195
  def reg(self, reg:str) -> AMRegister: return self.__dict__[reg]
314
196
 
315
197
  def rreg(self, reg:int) -> int:
@@ -335,62 +217,45 @@ class AMDev:
335
217
  self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
336
218
  self.reg("regBIF_BX_PF0_RSMU_DATA").write(val)
337
219
 
338
- def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int:
339
- for _ in range(timeout):
340
- if ((rval:=reg.read()) & mask) == value: return rval
341
- time.sleep(0.001)
342
- raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval}')
343
-
344
220
  def _run_discovery(self):
345
221
  # NOTE: Fixed register to query memory size without known ip bases to find the discovery table.
346
222
  # The table is located at the end of VRAM - 64KB and is 10KB in size.
347
223
  mmRCC_CONFIG_MEMSIZE = 0xde3
348
224
  self.vram_size = self.rreg(mmRCC_CONFIG_MEMSIZE) << 20
225
+ tmr_offset, tmr_size = self.vram_size - (64 << 10), (10 << 10)
349
226
 
350
- bhdr = am.struct_binary_header.from_address(self.paddr2cpu(self.vram_size - (64 << 10)))
351
- ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(bhdr) + bhdr.table_list[am.IP_DISCOVERY].offset)
352
- assert ihdr.signature == am.DISCOVERY_TABLE_SIGNATURE and not ihdr.base_addr_64_bit, f"0x{ihdr.signature:X} != 0x{am.DISCOVERY_TABLE_SIGNATURE:X}"
227
+ self.bhdr = am.struct_binary_header.from_buffer(bytearray(self.vram.view(tmr_offset, tmr_size)[:]))
228
+ ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(self.bhdr) + self.bhdr.table_list[am.IP_DISCOVERY].offset)
229
+ assert self.bhdr.binary_signature == am.BINARY_SIGNATURE and ihdr.signature == am.DISCOVERY_TABLE_SIGNATURE, "discovery signatures mismatch"
353
230
 
354
231
  # Mapping of HW IP to Discovery HW IP
355
232
  hw_id_map = {am.__dict__[x]: int(y) for x,y in am.hw_id_map}
356
- self.regs_offset:dict[int, dict[int, list]] = collections.defaultdict(dict)
357
- self.ip_versions:dict[int, int] = {}
233
+ self.regs_offset:dict[int, dict[int, tuple]] = collections.defaultdict(dict)
234
+ self.ip_ver:dict[int, tuple[int, int, int]] = {}
358
235
 
359
236
  for num_die in range(ihdr.num_dies):
360
- dhdr = am.struct_die_header.from_address(ctypes.addressof(bhdr) + ihdr.die_info[num_die].die_offset)
237
+ dhdr = am.struct_die_header.from_address(ctypes.addressof(self.bhdr) + ihdr.die_info[num_die].die_offset)
361
238
 
362
- ip_offset = ctypes.addressof(bhdr) + ctypes.sizeof(dhdr) + ihdr.die_info[num_die].die_offset
239
+ ip_offset = ctypes.addressof(self.bhdr) + ctypes.sizeof(dhdr) + ihdr.die_info[num_die].die_offset
363
240
  for _ in range(dhdr.num_ips):
364
241
  ip = am.struct_ip_v4.from_address(ip_offset)
365
242
  ba = (ctypes.c_uint32 * ip.num_base_address).from_address(ip_offset + 8)
366
243
  for hw_ip in range(1, am.MAX_HWIP):
367
244
  if hw_ip in hw_id_map and hw_id_map[hw_ip] == ip.hw_id:
368
- self.regs_offset[hw_ip][ip.instance_number] = list(ba)
369
- self.ip_versions[hw_ip] = int(f"{ip.major:02d}{ip.minor:02d}{ip.revision:02d}")
245
+ self.regs_offset[hw_ip][ip.instance_number] = tuple(list(ba))
246
+ self.ip_ver[hw_ip] = (ip.major, ip.minor, ip.revision)
370
247
 
371
248
  ip_offset += 8 + (8 if ihdr.base_addr_64_bit else 4) * ip.num_base_address
372
249
 
373
- gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(bhdr) + bhdr.table_list[am.GC].offset)
250
+ gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(self.bhdr) + self.bhdr.table_list[am.GC].offset)
374
251
  self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)
375
252
 
376
- def _ip_module(self, prefix:str, hwip):
377
- version = [self.ip_versions[hwip]//10000, (self.ip_versions[hwip]//100)%100, self.ip_versions[hwip]%100]
378
- for ver in [version, version[:2]+[0], version[:1]+[0, 0]]:
379
- try: return __import__(f"tinygrad.runtime.autogen.am.{prefix}_{ver[0]}_{ver[1]}_{ver[2]}", fromlist=[f"{prefix}_{ver[0]}_{ver[1]}_{ver[2]}"])
380
- except ImportError: pass
381
- raise ImportError(f"am {self.devfmt}: failed to load {prefix} module with version {version}")
253
+ def _ip_module(self, prefix:str, hwip, prever_prefix:str=""): return import_module(prefix, self.ip_ver[hwip], prever_prefix)
382
254
 
383
255
  def _build_regs(self):
384
- mods = [("MP0", self._ip_module("mp", am.MP0_HWIP)), ("NBIO", self._ip_module("nbio", am.NBIO_HWIP)), ("GC", self._ip_module("gc", am.GC_HWIP)),
385
- ("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP))]
386
- for base, module in mods:
387
- rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm
388
- reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))
389
- reg_fields: dict[str, dict[str, tuple]] = collections.defaultdict(dict)
390
- for k, val in module.__dict__.items():
391
- if k.endswith("_MASK") and ((rname:=k.split("__")[0]) in reg_names):
392
- reg_fields[rname][k[2+len(rname):-5].lower()] = (val, module.__dict__.get(f"{k[:-5]}__SHIFT", val.bit_length() - 1))
393
-
394
- for k, regval in module.__dict__.items():
395
- if k.startswith(rpref) and not k.endswith("_BASE_IDX") and (base_idx:=getattr(module, f"{k}_BASE_IDX", None)) is not None:
396
- setattr(self, k, AMRegister(self, self.ip_base(base, 0, base_idx) + regval, reg_fields.get(k[len(rpref):], {})))
256
+ mods = [("mp", am.MP0_HWIP), ("hdp", am.HDP_HWIP), ("gc", am.GC_HWIP), ("mmhub", am.MMHUB_HWIP), ("osssys", am.OSSSYS_HWIP),
257
+ ("nbio" if self.ip_ver[am.GC_HWIP] < (12,0,0) else "nbif", am.NBIO_HWIP)]
258
+
259
+ for prefix, hwip in mods:
260
+ self.__dict__.update(import_asic_regs(prefix, self.ip_ver[hwip], cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[hwip])))
261
+ self.__dict__.update(import_asic_regs('mp', (11, 0), cls=functools.partial(AMRegister, adev=self, bases=self.regs_offset[am.MP1_HWIP])))