tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_webgpu.py
@@ -0,0 +1,63 @@
+ import functools, struct
+ from tinygrad.device import Compiled, Allocator, Compiler
+ from tinygrad.renderer.wgsl import WGSLRenderer
+ from tinygrad.helpers import round_up
+ import wgpu
+
+ def create_uniform(wgpu_device, val) -> wgpu.GPUBuffer:
+   buf = wgpu_device.create_buffer(size=4, usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST)
+   wgpu_device.queue.write_buffer(buf, 0, val.to_bytes(4, "little") if isinstance(val, int) else struct.pack('<f', val))
+   return buf
+
+ class WebGPUProgram:
+   def __init__(self, dev, name:str, lib:bytes):
+     (self.dev, self.timestamp_supported) = dev
+     self.name, self.lib, self.prg = name, lib, self.dev.create_shader_module(code=lib.decode())  # NOTE: this is the compiler
+   def __call__(self, *bufs, global_size=(1,1,1), local_size=(1,1,1), vals=(), wait=False):
+     wait = wait and self.timestamp_supported
+     binding_layouts = [{"binding": 0, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.uniform}}]
+     binding_layouts += [{"binding": i+1, "visibility": wgpu.ShaderStage.COMPUTE,
+       "buffer": {"type": wgpu.BufferBindingType.uniform if i >= len(bufs) else wgpu.BufferBindingType.storage}} for i in range(len(bufs)+len(vals))]  # noqa: E501
+     bindings = [{"binding": 0, "resource": {"buffer": create_uniform(self.dev, float('inf')), "offset": 0, "size": 4}}]
+     bindings += [{"binding": i+1, "resource": {"buffer": create_uniform(self.dev, x) if i >= len(bufs) else x, "offset": 0,
+       "size": 4 if i >= len(bufs) else x.size}} for i,x in enumerate(bufs+vals)]  # noqa: E501
+     bind_group_layout = self.dev.create_bind_group_layout(entries=binding_layouts)
+     pipeline_layout = self.dev.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
+     bind_group = self.dev.create_bind_group(layout=bind_group_layout, entries=bindings)
+     compute_pipeline = self.dev.create_compute_pipeline(layout=pipeline_layout, compute={"module": self.prg, "entry_point": self.name})
+     command_encoder = self.dev.create_command_encoder()
+     if wait:
+       query_set = self.dev.create_query_set(type=wgpu.QueryType.timestamp, count=2)
+       query_buf = self.dev.create_buffer(size=16, usage=wgpu.BufferUsage.QUERY_RESOLVE | wgpu.BufferUsage.COPY_SRC)
+       timestamp_writes = {"query_set": query_set, "beginning_of_pass_write_index": 0, "end_of_pass_write_index": 1}
+     compute_pass = command_encoder.begin_compute_pass(timestamp_writes=timestamp_writes if wait else None)  # pylint: disable=E0606
+     compute_pass.set_pipeline(compute_pipeline)
+     compute_pass.set_bind_group(0, bind_group, [], 0, 999999)  # last 2 not used
+     compute_pass.dispatch_workgroups(*global_size)  # x y z
+     compute_pass.end()
+     if wait:
+       command_encoder.resolve_query_set(query_set=query_set, first_query=0, query_count=2, destination=query_buf, destination_offset=0)
+     self.dev.queue.submit([command_encoder.finish()])
+     return ((timestamps:=self.dev.queue.read_buffer(query_buf).cast("Q").tolist())[1] - timestamps[0]) / 1e9 if wait else None
+
+ # WebGPU buffers have to be 4-byte aligned
+ class WebGpuAllocator(Allocator):
+   def __init__(self, dev): self.dev = dev
+   def _alloc(self, size: int, options):
+     return self.dev.create_buffer(size=round_up(size, 4), usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
+   def _copyin(self, dest, src: memoryview):
+     if src.nbytes % 4:
+       padded_src = bytearray(round_up(src.nbytes, 4))
+       padded_src[:src.nbytes] = src
+     self.dev.queue.write_buffer(dest, 0, padded_src if src.nbytes % 4 else src)
+   def _copyout(self, dest: memoryview, src):
+     buffer_data = self.dev.queue.read_buffer(src, 0)
+     dest[:] = buffer_data[:dest.nbytes] if src._nbytes > dest.nbytes else buffer_data
+
+ class WebGpuDevice(Compiled):
+   def __init__(self, device:str):
+     adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance")
+     timestamp_supported = wgpu.FeatureName.timestamp_query in adapter.features
+     wgpu_device = adapter.request_device_sync(required_features=[wgpu.FeatureName.timestamp_query] if timestamp_supported else [])
+     super().__init__(device, WebGpuAllocator(wgpu_device), WGSLRenderer(), Compiler(),
+                      functools.partial(WebGPUProgram, (wgpu_device, timestamp_supported)))
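Note on the kernel ABI above: binding 0 is always a 4-byte uniform initialized to inf, bindings 1..len(bufs) are the storage buffers, and any trailing integer vals are bound as 4-byte uniforms. A minimal standalone sketch of the packing rule create_uniform applies (plain Python, no GPU required; pack_uniform is an illustrative name, not part of the package):

    import struct

    # ints become 4 little-endian bytes, floats an IEEE-754 single ('<f'), mirroring create_uniform
    def pack_uniform(val) -> bytes:
      return val.to_bytes(4, "little") if isinstance(val, int) else struct.pack('<f', val)

    assert pack_uniform(7) == b"\x07\x00\x00\x00"
    assert pack_uniform(1.0) == struct.pack('<f', 1.0)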
tinygrad/runtime/support/allocator.py
@@ -0,0 +1,94 @@
+ import collections
+ from tinygrad.helpers import round_up
+
+ class TLSFAllocator:
+   """
+   The allocator is based on the Two-Level Segregated Fit (TLSF) algorithm. The allocator maintains 2 levels of buckets:
+     * 1st level is determined by the most significant bit of the size.
+     * 2nd level splits the covered memory of the 1st level into @lv2_cnt entries.
+
+   For each allocation request, the allocator searches for the smallest block that can fit the requested size.
+   For each deallocation request, the allocator merges the block with its neighbors if they are free.
+   """
+
+   def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
+     self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
+     self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
+     self.lv1_entries:list[int] = [0] * len(self.storage)
+
+     # self.blocks is more like a linked list, where each entry is a contiguous block.
+     self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free
+     self._insert_block(0, size)
+
+   def lv1(self, size): return size.bit_length()
+   def lv2(self, size): return (size - (1 << (size.bit_length() - 1))) // (1 << max(0, size.bit_length() - self.l2_cnt))
+
+   def _insert_block(self, start:int, size:int, prev:int|None=None):
+     if prev is None: prev = self.blocks[start][2]
+     self.storage[self.lv1(size)][self.lv2(size)].append(start)
+     self.lv1_entries[self.lv1(size)] += 1
+     self.blocks[start] = (size, start + size, prev, True)
+     return self
+
+   def _remove_block(self, start:int, size:int, prev:int|None=None):
+     if prev is None: prev = self.blocks[start][2]
+     self.storage[self.lv1(size)][self.lv2(size)].remove(start)
+     self.lv1_entries[self.lv1(size)] -= 1
+     self.blocks[start] = (size, start + size, prev, False)
+     return self
+
+   def _split_block(self, start:int, size:int, new_size:int):
+     nxt = self.blocks[start][1]
+     assert self.blocks[start][3], "block must be free"
+     self._remove_block(start, size)._insert_block(start, new_size)._insert_block(start + new_size, size - new_size, prev=start)
+     if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start + new_size, self.blocks[nxt][3])
+     return self
+
+   def _merge_right(self, start:int):
+     size, nxt, _, is_free = self.blocks[start]
+     assert is_free, "block must be free"
+
+     while is_free and nxt in self.blocks:
+       if (blk:=self.blocks[nxt])[3] is False: break
+       self._remove_block(start, size)._remove_block(nxt, blk[0])._insert_block(start, size:=size + blk[0])
+       assert self.blocks[start][1] == blk[1]
+       _, nxt, _, _ = self.blocks.pop(nxt)
+
+     if nxt in self.blocks: self.blocks[nxt] = (self.blocks[nxt][0], self.blocks[nxt][1], start, self.blocks[nxt][3])
+
+   def _merge_block(self, start:int):
+     # Go left while blocks are free, then merge them all to the right.
+     while (x:=self.blocks[start][2]) is not None and self.blocks[x][3] is True: start = x
+     self._merge_right(start)
+
+   def alloc(self, req_size:int, align:int=1) -> int:
+     req_size = max(self.block_size, req_size) # at least block size.
+     size = max(self.block_size, req_size + align - 1)
+
+     # Round up the allocation size to the next bucket, so any entry there can fit the requested size.
+     size = round_up(size, (1 << size.bit_length() - self.l2_cnt))
+
+     # Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
+     for l1 in range(self.lv1(size), len(self.storage)):
+       if self.lv1_entries[l1] == 0: continue
+       for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
+         if len(self.storage[l1][l2]) > 0:
+           nsize = self.blocks[self.storage[l1][l2][0]][0]
+           assert nsize >= size, "block must be larger"
+
+           # Block start address.
+           start = self.storage[l1][l2][0]
+
+           # If request contains alignment, split the block into two parts.
+           if (new_start:=round_up(start, align)) != start:
+             self._split_block(start, nsize, new_start - start)
+             start, nsize = new_start, self.blocks[new_start][0]
+
+           # If the block is larger than the requested size, split it into two parts.
+           if nsize > req_size: self._split_block(start, nsize, req_size)
+           self._remove_block(start, req_size) # Mark the block as allocated.
+           return start + self.base
+     raise MemoryError(f"Can't allocate {req_size} bytes")
+
+   def free(self, start:int):
+     self._insert_block(start - self.base, self.blocks[start - self.base][0])._merge_block(start - self.base)
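A quick usage sketch of TLSFAllocator as defined above (the arena size, base, and request sizes are arbitrary illustrations):

    # carve allocations out of a 1 MiB arena whose first byte sits at address 0x1000
    arena = TLSFAllocator(1 << 20, base=0x1000)

    a = arena.alloc(100)              # rounded up internally to at least block_size (16)
    b = arena.alloc(4096, align=256)  # alignment is applied to the offset from base
    assert (b - 0x1000) % 256 == 0

    arena.free(a)                     # freed blocks coalesce with free neighbors
    arena.free(b)                     # after both frees the arena is one free block again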
tinygrad/runtime/support/am/__init__.py (file without changes)
tinygrad/runtime/support/am/amdev.py
@@ -0,0 +1,384 @@
+ from __future__ import annotations
+ import ctypes, collections, time, dataclasses, pathlib, fcntl, os
+ from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
+ from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhub_3_0_0, gc_11_0_0, osssys_6_0_0
+ from tinygrad.runtime.support.allocator import TLSFAllocator
+ from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
+
+ AM_DEBUG = getenv("AM_DEBUG", 0)
+
+ @dataclasses.dataclass(frozen=True)
+ class AMRegister:
+   adev:AMDev; reg_off:int; reg_fields:dict[str, tuple[int, int]] # noqa: E702
+
+   def _parse_kwargs(self, **kwargs):
+     mask, values = 0xffffffff, 0
+     for k, v in kwargs.items():
+       if k not in self.reg_fields: raise ValueError(f"Unknown register field: {k}. {self.reg_fields.keys()}")
+       m, s = self.reg_fields[k]
+       if v & (m>>s) != v: raise ValueError(f"Value {v} for {k} is out of range {m=} {s=}")
+       mask &= ~m
+       values |= v << s
+     return mask, values
+
+   def build(self, **kwargs) -> int: return self._parse_kwargs(**kwargs)[1]
+
+   def update(self, **kwargs): self.write(value=self.read(), **kwargs)
+
+   def write(self, value=0, **kwargs):
+     mask, values = self._parse_kwargs(**kwargs)
+     self.adev.wreg(self.reg_off, (value & mask) | values)
+
+   def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0]
+
+ class AMFirmware:
+   def __init__(self):
+     # Load SOS firmware
+     self.sos_fw = {}
+
+     blob, sos_hdr = self.load_fw("psp_13_0_0_sos.bin", am.struct_psp_firmware_header_v2_0)
+     fw_bin = sos_hdr.psp_fw_bin
+
+     for fw_i in range(sos_hdr.psp_fw_bin_count):
+       fw_bin_desc = am.struct_psp_fw_bin_desc.from_address(ctypes.addressof(fw_bin) + fw_i * ctypes.sizeof(am.struct_psp_fw_bin_desc))
+       ucode_start_offset = fw_bin_desc.offset_bytes + sos_hdr.header.ucode_array_offset_bytes
+       self.sos_fw[fw_bin_desc.fw_type] = blob[ucode_start_offset:ucode_start_offset+fw_bin_desc.size_bytes]
+
+     # Load other fw
+     self.ucode_start: dict[str, int] = {}
+     self.descs: list[tuple[int, memoryview]] = []
+
+     blob, hdr = self.load_fw("smu_13_0_0.bin", am.struct_smc_firmware_header_v1_0)
+     self.smu_psp_desc = self.desc(am.GFX_FW_TYPE_SMU, blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes)
+
+     # SDMA firmware
+     blob, hdr = self.load_fw("sdma_6_0_0.bin", am.struct_sdma_firmware_header_v2_0)
+     self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH0, blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes)]
+     self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH1, blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes)]
+
+     # PFP, ME, MEC firmware
+     for (fw_name, fw_cnt) in [('PFP', 2), ('ME', 2), ('MEC', 4)]:
+       blob, hdr = self.load_fw(f"gc_11_0_0_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0)
+
+       # Code part
+       self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'), blob, hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes)]
+
+       # Stack
+       fw_types = [getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}_P{fwnum}_STACK') for fwnum in range(fw_cnt)]
+       self.descs += [self.desc(typ, blob, hdr.data_offset_bytes, hdr.data_size_bytes) for typ in fw_types]
+       self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32)
+
+     # IMU firmware
+     blob, hdr = self.load_fw("gc_11_0_0_imu.bin", am.struct_imu_firmware_header_v1_0)
+     imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes
+     self.descs += [self.desc(am.GFX_FW_TYPE_IMU_I, blob, imu_i_off, imu_i_sz), self.desc(am.GFX_FW_TYPE_IMU_D, blob, imu_i_off + imu_i_sz, imu_d_sz)]
+
+     # RLC firmware
+     blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw("gc_11_0_0_rlc.bin", am.struct_rlc_firmware_header_v2_0,
+       am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3)
+
+     for mem in ['GPM', 'SRM']:
+       off, sz = getattr(hdr1, f'save_restore_list_{mem.lower()}_offset_bytes'), getattr(hdr1, f'save_restore_list_{mem.lower()}_size_bytes')
+       self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_RESTORE_LIST_{mem}_MEM'), blob, off, sz)]
+
+     for mem,fmem in [('IRAM', 'iram'), ('DRAM_BOOT', 'dram')]:
+       off, sz = getattr(hdr2, f'rlc_{fmem}_ucode_offset_bytes'), getattr(hdr2, f'rlc_{fmem}_ucode_size_bytes')
+       self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}'), blob, off, sz)]
+
+     for mem in ['P', 'V']:
+       off, sz = getattr(hdr3, f'rlc{mem.lower()}_ucode_offset_bytes'), getattr(hdr3, f'rlc{mem.lower()}_ucode_size_bytes')
+       self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RLC_{mem}'), blob, off, sz)]
+
+     self.descs += [self.desc(am.GFX_FW_TYPE_RLC_G, blob, hdr0.header.ucode_array_offset_bytes, hdr0.header.ucode_size_bytes)]
+
+   def load_fw(self, fname:str, *headers):
+     fpath = next(f for loc in ["/lib/firmware/updates/amdgpu/", "/lib/firmware/amdgpu/"] if (f:=pathlib.Path(loc + fname)).exists())
+     blob = memoryview(bytearray(fpath.read_bytes()))
+     return tuple([blob] + [hdr.from_address(mv_address(blob)) for hdr in headers])
+
+   def desc(self, typ:int, blob:memoryview, offset:int, size:int) -> tuple[int, memoryview]: return (typ, blob[offset:offset+size])
+
+ @dataclasses.dataclass(frozen=True)
+ class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
+
+ class AMPageTableEntry:
+   def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.entries, self.lv = adev, paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv
+
+   def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
+     assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"
+
+     f = (am.AMDGPU_PTE_VALID if valid else 0) | ((am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE) if not table else 0) \
+       | am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if not table and self.lv != am.AMDGPU_VM_PTB else 0) \
+       | ((am.AMDGPU_PTE_SYSTEM) if system else 0) | ((am.AMDGPU_PTE_SNOOPED) if snooped else 0) \
+       | (am.AMDGPU_PTE_MTYPE_NV10(0, am.MTYPE_UC) if uncached else 0)
+     self.entries[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f
+
+ class AMPageTableTraverseContext:
+   def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False):
+     self.adev, self.vaddr, self.create_pts, self.free_pts = adev, vaddr - adev.gmc.vm_base, create_pts, free_pts
+     self.pt_stack:list[tuple[AMPageTableEntry, int, int]] = [(pt, self._pt_pte_idx(pt, vaddr), self._pt_pte_size(pt))]
+
+   def _pt_pte_size(self, pt): return (1 << ((9 * (3-pt.lv)) + 12))
+   def _pt_pte_idx(self, pt, va): return (va // self._pt_pte_size(pt)) % 512
+
+   def level_down(self):
+     pt, pte_idx, _ = self.pt_stack[-1]
+     if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0:
+       assert self.create_pts, "Not allowed to create new page table"
+       pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True), table=True, valid=True)
+       entry = pt.entries[pte_idx]
+
+     assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
+     child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)
+
+     self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
+     return self.pt_stack[-1]
+
+   def _try_free_pt(self) -> bool:
+     pt, _, _ = self.pt_stack[-1]
+     if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
+       self.adev.mm.pfree(pt.paddr)
+       parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
+       parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
+       return True
+     return False
+
+   def level_up(self):
+     while self._try_free_pt() or self.pt_stack[-1][1] == 512:
+       _, pt_cnt, _ = self.pt_stack.pop()
+       if pt_cnt == 512: self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])
+
+   def next(self, size:int, off=0):
+     while size > 0:
+       pt, pte_idx, pte_covers = self.pt_stack[-1]
+       if self.create_pts:
+         while pte_covers > size: pt, pte_idx, pte_covers = self.level_down()
+       else:
+         while pt.lv!=am.AMDGPU_VM_PTB and (pt.entries[pte_idx] & am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down()
+
+       entries = min(size // pte_covers, 512 - pte_idx)
+       assert entries > 0, "Invalid entries"
+       yield off, pt, pte_idx, entries, pte_covers
+
+       size, off, self.vaddr = size - entries * pte_covers, off + entries * pte_covers, self.vaddr + entries * pte_covers
+       self.pt_stack[-1] = (pt, pte_idx + entries, pte_covers)
+       self.level_up()
+
+ class AMMemoryManager:
+   va_allocator = TLSFAllocator(512 * (1 << 30), base=0x7F0000000000) # global for all devices.
+
+   def __init__(self, adev:AMDev, vram_size:int):
+     self.adev, self.vram_size = adev, vram_size
+     self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device
+     self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
+     self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)
+
+   def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
+     if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")
+
+     assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
+
+     ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True)
+     for paddr, psize in paddrs:
+       for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
+         for pte_off in range(pte_cnt):
+           assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}"
+           pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers,
+             uncached=uncached, system=system, snooped=snooped, frag=0 if pte_covers == 0x1000 else 0x9, valid=True)
+
+     # Invalidate TLB after mappings.
+     self.adev.gmc.flush_tlb(ip='GC', vmid=0)
+     self.adev.gmc.flush_tlb(ip='MM', vmid=0)
+     return AMMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)
+
+   def unmap_range(self, vaddr:int, size:int):
+     if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")
+
+     ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True)
+     for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
+       for pte_id in range(pte_idx, pte_idx + pte_cnt):
+         assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}"
+         pt.set_entry(pte_id, paddr=0x0, valid=False)
+
+   @staticmethod
+   def alloc_vaddr(size:int, align=0x1000) -> int: return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
+
+   def valloc(self, size:int, align=0x1000, uncached=False, contigous=False) -> AMMapping:
+     # Alloc physical memory and map it to the virtual address
+     va = self.alloc_vaddr(size, align)
+
+     if contigous: paddrs = [(self.palloc(size, zero=True), size)]
+     else:
+       paddrs = []
+       try:
+         ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
+         for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False), seg_size) for _ in range(seg_cnt)]
+       except MemoryError:
+         for paddr, _ in paddrs: self.pa_allocator.free(paddr)
+         raise
+
+     return self.map_range(va, size, paddrs, uncached=uncached)
+
+   def vfree(self, vm:AMMapping):
+     self.unmap_range(vm.va_addr, vm.size)
+     self.va_allocator.free(vm.va_addr)
+     for paddr, _ in vm.paddrs: self.pa_allocator.free(paddr)
+
+   def palloc(self, size:int, align:int=0x1000, zero=True, boot=False) -> int:
+     assert self.adev.is_booting == boot, "During booting, only boot memory can be allocated"
+     paddr = (self.boot_allocator if boot else self.pa_allocator).alloc(round_up(size, 0x1000), align)
+     if zero: ctypes.memset(self.adev.paddr2cpu(paddr), 0, size)
+     return paddr
+
+   def pfree(self, paddr:int): self.pa_allocator.free(paddr)
+
+ class AMDev:
+   def __init__(self, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
+     self.devfmt = devfmt
+     self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar
+
+     os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions
+
+     # Avoid O_CREAT because we don't want to re-create/replace an existing file (triggers extra perms checks) when opening as non-owner.
+     if os.path.exists(lock_name:=temp(f"am_{self.devfmt}.lock")): self.lock_fd = os.open(lock_name, os.O_RDWR)
+     else: self.lock_fd = os.open(lock_name, os.O_RDWR | os.O_CREAT, 0o666)
+
+     try: fcntl.flock(self.lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
+     except OSError: raise RuntimeError(f"Failed to open AM device {self.devfmt}. It's already in use.")
+
+     self._run_discovery()
+     self._build_regs()
+
+     # AM boot process:
+     # The GPU being passed can be in one of several states: 1. Not initialized. 2. Initialized by amdgpu. 3. Initialized by AM.
+     # The 1st and 2nd states require a full GPU setup since their states are unknown. The 2nd state also requires a mode1 reset to
+     # reinitialize all components.
+     #
+     # The 3rd state can be set up partially to optimize boot time. In this case, only the GFX and SDMA IPs need to be initialized.
+     # To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for
+     # all blocks that are initialized only during the initial AM boot.
+     # To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag.
+     self.is_booting, self.smi_dev = True, False # During boot, only boot memory can be allocated; this flag validates that.
+     self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000002)) and (getenv("AM_RESET", 0) != 1)
+
+     # Memory manager & firmware
+     self.mm = AMMemoryManager(self, self.vram_size)
+     self.fw = AMFirmware()
+
+     # Initialize IP blocks
+     self.soc21:AM_SOC21 = AM_SOC21(self)
+     self.gmc:AM_GMC = AM_GMC(self)
+     self.ih:AM_IH = AM_IH(self)
+     self.psp:AM_PSP = AM_PSP(self)
+     self.smu:AM_SMU = AM_SMU(self)
+     self.gfx:AM_GFX = AM_GFX(self)
+     self.sdma:AM_SDMA = AM_SDMA(self)
+
+     if self.partial_boot and (self.reg("regCP_MEC_RS64_CNTL").read() & gc_11_0_0.CP_MEC_RS64_CNTL__MEC_HALT_MASK == 0):
+       if DEBUG >= 2: print(f"am {self.devfmt}: MEC is active. Issue a full reset.")
+       self.partial_boot = False
+
+     if not self.partial_boot:
+       if self.psp.is_sos_alive() and self.smu.is_smu_alive(): self.smu.mode1_reset()
+       for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
+         ip.init()
+         if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
+
+     # Booting done
+     self.is_booting = False
+
+     # Re-initialize main blocks
+     for ip in [self.gfx, self.sdma]:
+       ip.init()
+       if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
+
+     self.smu.set_clocks(level=-1) # last level, max perf.
+     self.gfx.set_clockgating_state()
+     self.reg("regSCRATCH_REG7").write(am_version)
+     if DEBUG >= 2: print(f"am {self.devfmt}: boot done")
+
+   def fini(self):
+     for ip in [self.sdma, self.gfx]: ip.fini()
+     self.smu.set_clocks(level=0)
+
+   def paddr2cpu(self, paddr:int) -> int: return mv_address(self.vram) + paddr
+   def paddr2mc(self, paddr:int) -> int: return self.gmc.mc_base + paddr
+
+   def ip_base(self, ip:str, inst:int, seg:int) -> int: return self.regs_offset[am.__dict__[f"{ip}_HWIP"]][inst][seg]
+
+   def reg(self, reg:str) -> AMRegister: return self.__dict__[reg]
+
+   def rreg(self, reg:int) -> int:
+     val = self.indirect_rreg(reg * 4) if reg > len(self.mmio) else self.mmio[reg]
+     if AM_DEBUG >= 4 and getattr(self, '_prev_rreg', None) != (reg, val): print(f"am {self.devfmt}: Reading register {reg:#x} with value {val:#x}")
+     self._prev_rreg = (reg, val)
+     return val
+
+   def wreg(self, reg:int, val:int):
+     if AM_DEBUG >= 4: print(f"am {self.devfmt}: Writing register {reg:#x} with value {val:#x}")
+     if reg > len(self.mmio): self.indirect_wreg(reg * 4, val)
+     else: self.mmio[reg] = val
+
+   def wreg_pair(self, reg_base:str, lo_suffix:str, hi_suffix:str, val:int):
+     self.reg(f"{reg_base}{lo_suffix}").write(val & 0xffffffff)
+     self.reg(f"{reg_base}{hi_suffix}").write(val >> 32)
+
+   def indirect_rreg(self, reg:int) -> int:
+     self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
+     return self.reg("regBIF_BX_PF0_RSMU_DATA").read()
+
+   def indirect_wreg(self, reg:int, val:int):
+     self.reg("regBIF_BX_PF0_RSMU_INDEX").write(reg)
+     self.reg("regBIF_BX_PF0_RSMU_DATA").write(val)
+
+   def wait_reg(self, reg:AMRegister, value:int, mask=0xffffffff, timeout=10000) -> int:
+     for _ in range(timeout):
+       if ((rval:=reg.read()) & mask) == value: return rval
+       time.sleep(0.001)
+     raise RuntimeError(f'wait_reg timeout reg=0x{reg.reg_off:X} mask=0x{mask:X} value=0x{value:X} last_val=0x{rval:X}')
+
+   def _run_discovery(self):
+     # NOTE: Fixed register to query memory size without known ip bases to find the discovery table.
+     # The table is located at the end of VRAM - 64KB and is 10KB in size.
+     mmRCC_CONFIG_MEMSIZE = 0xde3
+     self.vram_size = self.rreg(mmRCC_CONFIG_MEMSIZE) << 20
+
+     bhdr = am.struct_binary_header.from_address(self.paddr2cpu(self.vram_size - (64 << 10)))
+     ihdr = am.struct_ip_discovery_header.from_address(ctypes.addressof(bhdr) + bhdr.table_list[am.IP_DISCOVERY].offset)
+     assert ihdr.signature == am.DISCOVERY_TABLE_SIGNATURE and not ihdr.base_addr_64_bit, f"0x{ihdr.signature:X} != 0x{am.DISCOVERY_TABLE_SIGNATURE:X}"
+
+     # Mapping of HW IP to Discovery HW IP
+     hw_id_map = {am.__dict__[x]: int(y) for x,y in am.hw_id_map}
+     self.regs_offset:dict[int, dict[int, list]] = collections.defaultdict(dict)
+     self.ip_versions:dict[int, int] = {}
+
+     for num_die in range(ihdr.num_dies):
+       dhdr = am.struct_die_header.from_address(ctypes.addressof(bhdr) + ihdr.die_info[num_die].die_offset)
+
+       ip_offset = ctypes.addressof(bhdr) + ctypes.sizeof(dhdr) + ihdr.die_info[num_die].die_offset
+       for _ in range(dhdr.num_ips):
+         ip = am.struct_ip_v4.from_address(ip_offset)
+         ba = (ctypes.c_uint32 * ip.num_base_address).from_address(ip_offset + 8)
+         for hw_ip in range(1, am.MAX_HWIP):
+           if hw_ip in hw_id_map and hw_id_map[hw_ip] == ip.hw_id:
+             self.regs_offset[hw_ip][ip.instance_number] = list(ba)
+             self.ip_versions[hw_ip] = int(f"{ip.major:02d}{ip.minor:02d}{ip.revision:02d}")
+
+         ip_offset += 8 + (8 if ihdr.base_addr_64_bit else 4) * ip.num_base_address
+
+     gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(bhdr) + bhdr.table_list[am.GC].offset)
+     self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)
+
+   def _build_regs(self):
+     mods = [("MP0", mp_13_0_0), ("MP1", mp_11_0), ("NBIO", nbio_4_3_0), ("MMHUB", mmhub_3_0_0), ("GC", gc_11_0_0), ("OSSSYS", osssys_6_0_0)]
+     for base, module in mods:
+       rpref = "mm" if base == "MP1" else "reg" # MP1 regs start with mm
+       reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))
+       reg_fields: dict[str, dict[str, tuple]] = collections.defaultdict(dict)
+       for k, val in module.__dict__.items():
+         if k.endswith("_MASK") and ((rname:=k.split("__")[0]) in reg_names):
+           reg_fields[rname][k[2+len(rname):-5].lower()] = (val, module.__dict__.get(f"{k[:-5]}__SHIFT", val.bit_length() - 1))
+
+       for k, regval in module.__dict__.items():
+         if k.startswith(rpref) and not k.endswith("_BASE_IDX") and (base_idx:=getattr(module, f"{k}_BASE_IDX", None)) is not None:
+           setattr(self, k, AMRegister(self, self.ip_base(base, 0, base_idx) + regval, reg_fields.get(k[len(rpref):], {})))
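The AMRegister helpers above implement read-modify-write over named bitfields: _parse_kwargs turns keyword arguments into a clear mask plus an OR value using the per-field (mask, shift) pairs that _build_regs collects from the autogen modules. A self-contained sketch of that packing (the register layout here is hypothetical, not a real AMD register):

    # hypothetical 32-bit register: "en" occupies bit 0, "mode" occupies bits 4-7
    REG_FIELDS = {"en": (0x00000001, 0), "mode": (0x000000f0, 4)}

    def parse_kwargs(reg_fields, **kwargs):
      mask, values = 0xffffffff, 0
      for k, v in kwargs.items():
        m, s = reg_fields[k]
        assert v & (m >> s) == v, f"value {v} for {k} is out of range"
        mask &= ~m        # clear the bits this field occupies
        values |= v << s  # place the new field value
      return mask, values

    mask, values = parse_kwargs(REG_FIELDS, en=1, mode=0xA)
    new = (0xdeadbeef & mask) | values  # read-modify-write, as AMRegister.write does
    assert new & 0x000000f1 == 0x000000a1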