tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
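Beyond the raw listing, the shape of the release is visible in the paths: the UOp IR moved from tinygrad/ops.py into a tinygrad/uop/ package, kernel optimization (kernel.py, search.py, tensor cores) was consolidated under tinygrad/codegen/opt/, scheduling split into tinygrad/schedule/, and the CLOUD runtime was replaced by REMOTE. A hedged sketch of the import updates this implies for downstream code; the module moves are attested by the list above, but the specific symbols shown are illustrative and old-path shims, if any, are not visible in this diff:

# import migration sketch for tinygrad 0.10.2 -> 0.11.0 (paths from the file list above)
# 0.10.2:
#   from tinygrad.ops import UOp, Ops
#   from tinygrad.codegen.kernel import Kernel
#   from tinygrad.engine.search import beam_search
# 0.11.0:
from tinygrad.uop.ops import UOp, Ops                 # tinygrad/ops.py -> tinygrad/uop/ops.py
from tinygrad.codegen.opt.kernel import Kernel        # codegen/kernel.py -> codegen/opt/kernel.py
from tinygrad.codegen.opt.search import beam_search   # engine/search.py -> codegen/opt/search.py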
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -2,14 +2,14 @@
 # a python uops emulator
 # works to test the tensor cores, and all the uops in general
 # this is the (living) definition of uops
-from typing import Optional, Any, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 import pickle, base64, itertools, time, struct, sys
 from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
 from tinygrad.helpers import all_same, getenv, flatten, get_single_element
 from tinygrad.device import Compiled, Compiler, Allocator
-from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
+from tinygrad.codegen.opt import tc
+from tinygrad.uop.ops import exec_alu, Ops, UOp, GroupOp
 from tinygrad.renderer import Renderer
-from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer
 
 def _load(m, i):
   if i is None: return 0.0
@@ -17,8 +17,8 @@ def _load(m, i):
   return m[i]
 
 def load(inp, j=0):
-  if len(inp) == 3: return [_load(m, x+j if x is not None else None) if gate else default for (m,x),default,gate in zip(*inp)]
-  return [_load(m, x+j if x is not None else None) for m,x in inp[0]]
+  if len(inp) == 2: return [_load(m, x+j if x is not None else None) if gate else default for (m,x,gate),default in zip(*inp)]
+  return [_load(m, x+j if x is not None else None) for m,x,_ in inp[0]]
 
 def _store(m, i, v):
   if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
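The signature change to load above reflects a new convention: the gate now travels inside each per-lane index tuple, (memory, offset, gate), rather than as a separate third input. A minimal sketch of the two input shapes, with hypothetical values:

# hypothetical two-lane warp; each index entry is (memory, offset, gate)
mem = memoryview(bytearray(4*4)).cast('I')
idx = [(mem, 0, True), (mem, 1, False)]        # lane 1 is gated off
# gated form, len(inp) == 2: (index tuples, per-lane defaults)
assert load([idx, [-1, -1]]) == [mem[0], -1]   # gated-off lane yields its default
# ungated form, len(inp) == 1: gates are ignored via (m, x, _)
assert load([idx]) == [mem[0], mem[1]]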
@@ -26,7 +26,7 @@ def _store(m, i, v):
 
 class PythonProgram:
   def __init__(self, name:str, lib:bytes):
-    self.uops: list[tuple[Ops, Optional[DType], list[int], Any]] = pickle.loads(lib)
+    self.uops: list[tuple[Ops, DType|None, list[int], Any]] = pickle.loads(lib)
   def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
     st = time.perf_counter()
     warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
@@ -40,79 +40,74 @@ class PythonProgram:
       loop_ends: dict[int, int] = {}
       while i < len(self.uops):
         uop, dtype, idp, arg = self.uops[i]
-        void_ops = {Ops.STORE, Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.NAME}
-        if uop is Ops.DEFINE_ACC: idp = [idp[0]]
+        void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
         inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
         dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
         if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
-        if uop is Ops.STORE:
-          if len(inp) == 2: inp.append([True] * len(inp[0])) # set the gate to True
-          if dtp[1].count > 1:
-            for j,val in enumerate(inp[1]):
-              for (m,o),v,g in zip(inp[0], val, inp[2]):
-                if g: _store(m, o+j, v)
-          else:
-            for (m,o),v,g in zip(*inp):
-              if g: _store(m, o, v)
-          i += 1
-          continue
         if uop is Ops.ENDRANGE:
           loop_ends[idp[0]] = i
           i = idp[0]
           continue
-        if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.NAME):
+        if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP):
           # in the python emulator, the warp is always in sync
           i += 1
           continue
         assert dtype is not None, f"{uop} is missing a dtype"
         dl[i] = dtype
-        if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL}:
-          assert dtype.fmt is not None and isinstance(dtype, PtrDType)
+        if uop is Ops.STORE:
+          for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
+            for (m,o,g),v in zip(inp[0], val):
+              if g: _store(m, o+j, v)
+          i += 1
+          continue
+        if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
+          assert isinstance(dtype, PtrDType), dtype
+          if dtype.fmt is None: raise RuntimeError(f"{dtype=} is not supported")
           if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
-          buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is Ops.DEFINE_LOCAL else pbufs.pop(0)
-          ul[i] = [buf.cast(dtype.fmt)] * warp_size
+          if uop is Ops.DEFINE_REG:
+            # REGs are per thread
+            ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(dtype.fmt) for _ in range(warp_size)]
+          else:
+            buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0)
+            ul[i] = [buf.cast(dtype.fmt)] * warp_size
         elif uop is Ops.DEFINE_VAR:
           ul[i] = [pvals.pop(0)] * warp_size
         elif uop is Ops.SPECIAL:
           if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
           elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
         elif uop is Ops.CONST: ul[i] = [arg] * warp_size
-        elif uop is Ops.DEFINE_ACC:
-          ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
         elif uop is Ops.INDEX:
-          ret = []
+          ret:list = []
           if isinstance(dtp[0], ImageDType):
             for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
              if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append((m, None))
              else: ret.append((m, ox*4 + oy*dtp[0].shape[1]*4))
           else:
            for m,o in zip(inp[0], inp[1]): ret.append((m,o))
-          ul[i] = ret
+          ul[i] = [(m,o,g) for (m,o),g in zip(ret, inp[2] if len(inp) == 3 else [True]*len(ret))] # set the gate last
         elif uop is Ops.CAST and isinstance(dtype, PtrDType):
           ul[i] = inp[0]
         elif uop is Ops.RANGE:
-          if i not in ul: ul[i] = [inp[0][0]] * warp_size
+          if i not in ul: ul[i] = [0] * warp_size
           else:
             for j in range(len(ul[i])):
               ul[i][j] += 1
-            if ul[i][0] == inp[1][0]:
+            if ul[i][0] == inp[0][0]:
               del ul[i]
               i = loop_ends[i] + 1
               continue
         elif uop is Ops.VECTORIZE: ul[i] = inp
-        elif uop in {Ops.CAST, Ops.BITCAST}:
+        elif uop is Ops.BITCAST:
           assert dtp[0].fmt and dtype.fmt
           pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
-          if uop is Ops.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
-          else: ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
+          ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
+        elif uop is Ops.CAST:
+          ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
         elif uop is Ops.LOAD:
           if dtype.count > 1:
             ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
           else:
             ul[i] = load(inp)
-        elif uop is Ops.ASSIGN:
-          for j in range(len(inp[0])): inp[0][j] = inp[1][j]
-          ul[i] = inp[0]
         elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
         elif uop is Ops.WMMA:
           # here are the models for the WMMA instruction on the different hardware
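Three behavioral notes on the hunk above: STORE moved below the dtype assert and now unpacks the gate from each index tuple; DEFINE_ACC and ASSIGN are gone in favor of DEFINE_REG, which allocates one buffer per lane instead of one shared buffer; and RANGE now always counts from 0, its single source being the end bound. What the new STORE branch does for a scalar store, as a standalone sketch with hypothetical values:

# per-lane (memory, offset, gate) tuples come from Ops.INDEX; gated lanes skip the write
mem = memoryview(bytearray(4*4)).cast('I')
index_val = [(mem, 0, True), (mem, 1, False)]
store_val = [7, 9]
for (m, o, g), v in zip(index_val, store_val):
  if g: _store(m, o, v)
assert list(mem[:2]) == [7, 0]   # lane 1 was gated off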
@@ -129,14 +124,27 @@ class PythonProgram:
                   out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
             return out
 
+          first_src_dtype = self.uops[idp[0]][1]
+          assert isinstance(first_src_dtype, DType) # mypy
+          dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
           # TODO: refactor these to a shared TensorCoreLayout in kernel.py
-          if arg[4] == "METAL":
+          if device == "METAL":
             # A (2 elements on 32 threads): row major
             def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
             # (i, j), C, D (2 elements on 32 threads): row major same as A/B
             def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
             ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
-          elif arg[4] == "AMD":
+          elif device == "AMD" and threads == 64:
+            def a_elem(x, k, row, goff): return x[k%4][goff + (k//4)*16 + row]
+            def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
+            def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
+            ul[i] = wmma_helper(64, 16, 4, 4, 4, a_elem, b_elem, c_map)
+          elif device == "AMD" and len(inp[0]) == 8: # RDNA4
+            def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
+            def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
+            def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
+            ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
+          elif device == "AMD":
             # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
             def a_elem(x, k, row, goff):
               assert x[k][goff+row] == x[k][goff+row+16], "warp elements not duplicated properly across lanes"
@@ -145,27 +153,27 @@ class PythonProgram:
             def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
             def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
             ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-          elif arg[4] == "CUDA":
+          elif device == "CUDA":
            # (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
            def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)
 
-            if arg[1] == (8,16,16):
+            if dims == (8,16,16):
              def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
              def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
              ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
 
-            elif arg[1] == (8,16,8) and arg[2] == dtypes.half:
+            elif dims == (8,16,8) and dtype_in == dtypes.half:
              def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
              def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
              ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
 
-            elif arg[1] == (8,16,8) and arg[2] == dtypes.float:
+            elif dims == (8,16,8) and dtype_in == dtypes.float:
              def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4]
              def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
              ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
 
            else: raise NotImplementedError(f"unimplemented tensor core {arg}")
-          elif arg[4] == "INTEL":
+          elif device == "INTEL":
             # A (16 elements on 8 threads)
             def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
             # B (16 elements on 8 threads)
@@ -173,7 +181,7 @@ class PythonProgram:
             # C, D (8 elements on 8 threads)
             def c_map(lane, elem): return (lane, elem)
             ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
-          elif arg[4] == "CPU":
+          elif device == "CPU":
             def elem(x, col, row, _): return x[col+row][0] # k is always 0
             def c_map(_, elem): return (elem%16, elem//16)
             ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
@@ -189,12 +197,14 @@ class PythonProgram:
 class PythonRenderer(Renderer):
   device = "PYTHON"
   def __init__(self):
-    if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
-    if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
-    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm80
-    if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm75
-    if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
-    if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", ClangRenderer.tensor_cores
+    if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", tc.metal
+    if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", tc.amd_rdna3
+    if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", tc.amd_cdna
+    if getenv("EMULATE_AMD_RDNA4"): self.device, self.tensor_cores = "AMD", tc.amd_rdna4
+    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
+    if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
+    if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
+    if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", tc.amx
 
   def render(self, uops:list[UOp]) -> str:
     lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
@@ -203,10 +213,10 @@ class PythonRenderer(Renderer):
 class PythonCompiler(Compiler):
   def compile(self, src:str) -> bytes: return base64.b64decode(src)
 
-class PythonAllocator(Allocator):
+class PythonAllocator(Allocator['PythonDevice']):
   def _alloc(self, size, options): return memoryview(bytearray(size))
   def _copyin(self, dest, src:memoryview): dest[:] = src
   def _copyout(self, dest:memoryview, src): dest[:] = src
 
 class PythonDevice(Compiled):
-  def __init__(self, device:str): super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)
+  def __init__(self, device:str): super().__init__(device, PythonAllocator(self), PythonRenderer(), PythonCompiler(), PythonProgram)
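Usage of the emulator is unchanged: select the PYTHON device and pick a hardware model with one of the EMULATE_* variables, which now resolve to tc.* layouts instead of renderer class attributes. A sketch, assuming the usual tinygrad convention that a device is selected by setting its name as an environment variable before import:

# run a small matmul through the pure-Python emulator, modeling Metal tensor cores
import os
os.environ["PYTHON"] = "1"          # select the PYTHON (emulator) device
os.environ["EMULATE_METAL"] = "1"   # use the tc.metal tensor core layout
from tinygrad import Tensor
print((Tensor.rand(16, 16) @ Tensor.rand(16, 16)).numpy())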
--- a/tinygrad/runtime/ops_qcom.py
+++ b/tinygrad/runtime/ops_qcom.py
@@ -1,11 +1,11 @@
 from __future__ import annotations
-import os, ctypes, functools, mmap, struct, array, math, sys
+import os, ctypes, functools, mmap, struct, array, math, sys, weakref
 assert sys.platform != 'win32'
 from types import SimpleNamespace
 from typing import Any, cast
 from tinygrad.device import BufferSpec
 from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQAllocatorBase, HCQSignal, HCQArgsState, BumpAllocator
-from tinygrad.runtime.support.hcq import HWInterface
+from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface
 from tinygrad.runtime.autogen import kgsl, adreno
 from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
 from tinygrad.renderer.cstyle import QCOMRenderer
@@ -37,17 +37,12 @@ class QCOMCompiler(CLCompiler):
   def disassemble(self, lib:bytes): fromimport('extra.disassemblers.adreno', 'disasm')(lib)
 
 class QCOMSignal(HCQSignal):
-  def __init__(self, base_addr:int|None=None, **kwargs):
-    super().__init__(QCOMDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=19.2)
-
-  def __del__(self):
-    if isinstance(self.base_addr, int): QCOMDevice.signals_pool.append(self.base_addr)
+  def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 19.2})
 
   def _sleep(self, time_spent_waiting_ms:int):
-    # Sleep only for only timeline signals. Do it immediately to free cpu.
-    if self.timeline_for_device is not None:
-      kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.timeline_for_device.fd, context_id=self.timeline_for_device.ctx,
-                                                  timestamp=self.timeline_for_device.last_cmd, timeout=0xffffffff)
+    # Sleep only for timeline signals. Do it immediately to free cpu.
+    if self.is_timeline and self.owner is not None:
+      kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.owner.fd, context_id=self.owner.ctx, timestamp=self.owner.last_cmd, timeout=0xffffffff)
 
 class QCOMComputeQueue(HWQueue):
   def __del__(self):
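QCOMSignal no longer pops addresses from a device-level signals_pool (also deleted from QCOMDevice further down); allocation is left to the HCQSignal base, and the subclass only pins timestamp_divider=19.2. The kwargs-merge idiom it uses forces that one key while forwarding everything else; a standalone sketch with hypothetical names:

# `**{**kwargs, 'timestamp_divider': 19.2}` wins over any caller-supplied value
def base_init(*args, timestamp_divider=1.0, **kwargs):
  return timestamp_divider
def sub_init(*args, **kwargs):
  return base_init(*args, **{**kwargs, 'timestamp_divider': 19.2})
assert sub_init(timestamp_divider=1.0) == 19.2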
@@ -135,7 +130,7 @@ class QCOMComputeQueue(HWQueue):
 
     self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
                                                                state_block=adreno.SB6_CS_SHADER, num_unit=1024 // 4),
-             *data64_le(args_state.ptr))
+             *data64_le(args_state.buf.va_addr))
     self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT,
                                                                state_block=adreno.SB6_CS_SHADER, num_unit=round_up(prg.image_size, 128) // 128),
              *data64_le(prg.lib_gpu.va_addr))
@@ -148,21 +143,21 @@ class QCOMComputeQueue(HWQueue):
     if args_state.prg.samp_cnt > 0:
       self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT,
                                                                  state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.samp_cnt),
-               *data64_le(args_state.ptr + args_state.prg.samp_off))
-      self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.ptr + args_state.prg.samp_off))
+               *data64_le(args_state.buf.va_addr + args_state.prg.samp_off))
+      self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.buf.va_addr + args_state.prg.samp_off))
       self.reg(adreno.REG_A6XX_SP_PS_TP_BORDER_COLOR_BASE_ADDR, *data64_le(prg.dev.border_color_buf.va_addr))
 
     if args_state.prg.tex_cnt > 0:
       self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
                                                                  state_block=adreno.SB6_CS_TEX, num_unit=min(16, args_state.prg.tex_cnt)),
-               *data64_le(args_state.ptr + args_state.prg.tex_off))
-      self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.ptr + args_state.prg.tex_off))
+               *data64_le(args_state.buf.va_addr + args_state.prg.tex_off))
+      self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.buf.va_addr + args_state.prg.tex_off))
 
     if args_state.prg.ibo_cnt > 0:
       self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST6_IBO, state_src=adreno.SS6_INDIRECT,
                                                                  state_block=adreno.SB6_CS_SHADER, num_unit=args_state.prg.ibo_cnt),
-               *data64_le(args_state.ptr + args_state.prg.ibo_off))
-      self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.ptr + args_state.prg.ibo_off))
+               *data64_le(args_state.buf.va_addr + args_state.prg.ibo_off))
+      self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.buf.va_addr + args_state.prg.ibo_off))
 
     self.reg(adreno.REG_A6XX_SP_CS_CONFIG,
              qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.prg.samp_cnt, ntex=args_state.prg.tex_cnt, nibo=args_state.prg.ibo_cnt))
@@ -171,24 +166,24 @@ class QCOMComputeQueue(HWQueue):
     return self
 
 class QCOMArgsState(HCQArgsState):
-  def __init__(self, ptr:int, prg:QCOMProgram, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=()):
-    super().__init__(ptr, prg, bufs, vals=vals)
+  def __init__(self, buf:HCQBuffer, prg:QCOMProgram, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=()):
+    super().__init__(buf, prg, bufs, vals=vals)
 
     if len(bufs) + len(vals) != len(prg.buf_info): raise RuntimeError(f'incorrect args size given={len(bufs)+len(vals)} != want={len(prg.buf_info)}')
 
-    self.buf_info, self.args_info, self.args_view = prg.buf_info[:len(bufs)], prg.buf_info[len(bufs):], to_mv(ptr, prg.kernargs_alloc_size).cast('Q')
+    self.buf_info, self.args_info = prg.buf_info[:len(bufs)], prg.buf_info[len(bufs):]
 
-    ctypes.memset(self.ptr, 0, prg.kernargs_alloc_size)
-    for cnst_val, cnst_off, cnst_sz in prg.consts_info: to_mv(self.ptr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')
+    ctypes.memset(cast(int, self.buf.va_addr), 0, prg.kernargs_alloc_size)
+    for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')
 
-    if prg.samp_cnt > 0: to_mv(self.ptr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
+    if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
     for i, b in enumerate(bufs):
       if prg.buf_info[i].type in {BUFTYPE_TEX, BUFTYPE_IBO}:
         obj = b.texture_info.desc if prg.buf_info[i].type is BUFTYPE_TEX else b.texture_info.ibo
-        to_mv(self.ptr + prg.buf_info[i].offset, len(obj) * 4).cast('I')[:] = array.array('I', obj)
-      self.bind_sints_to_ptr(b.va_addr, ptr=self.ptr + self.buf_info[i].offset + (0 if self.buf_info[i].type is BUFTYPE_BUF else 16), fmt='Q')
+        to_mv(self.buf.va_addr + prg.buf_info[i].offset, len(obj) * 4).cast('I')[:] = array.array('I', obj)
+      self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=self.buf_info[i].offset+(0 if self.buf_info[i].type is BUFTYPE_BUF else 16))
 
-    for i, v in enumerate(vals): self.bind_sints_to_ptr(v, ptr=self.ptr + self.args_info[i].offset, fmt='I')
+    for i, v in enumerate(vals): self.bind_sints_to_buf(v, buf=self.buf, fmt='I', offset=self.args_info[i].offset)
 
 class QCOMProgram(HCQProgram):
   def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
@@ -196,7 +191,7 @@ class QCOMProgram(HCQProgram):
     self.name, self.lib = name, lib
     self._parse_lib()
 
-    self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size, options=BufferSpec(cpu_access=True, nolru=True))
+    self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size, buf_spec:=BufferSpec(cpu_access=True, nolru=True))
     to_mv(cast(int, self.lib_gpu.va_addr), self.image_size)[:] = self.image
 
     self.pvtmem_size_per_item: int = round_up(self.pvtmem, 512) >> 9
@@ -208,6 +203,7 @@
 
     kernargs_alloc_size = round_up(2048 + (self.tex_cnt + self.ibo_cnt) * 0x40 + self.samp_cnt * 0x10, 0x100)
     super().__init__(QCOMArgsState, self.dev, self.name, kernargs_alloc_size=kernargs_alloc_size)
+    weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)
 
   def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
     if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch")
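The buf_spec walrus above feeds the weakref.finalize registered at the end of __init__, replacing the __del__ that the next hunk removes; a finalizer runs reliably at collection or interpreter exit, provided the callback does not capture self. A generic sketch of the pattern, with hypothetical names:

import weakref

class Allocator:
  def alloc(self, size): return bytearray(size)
  def free(self, buf): print(f"freed {len(buf)} bytes")

class Program:
  def __init__(self, dev:Allocator, size:int):
    self.lib_gpu = dev.alloc(size)
    # callback must not capture self, or self would never become unreachable
    weakref.finalize(self, dev.free, self.lib_gpu)

p = Program(Allocator(), 128)
del p   # finalizer runs here (or at interpreter exit at the latest)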
@@ -265,9 +261,6 @@ class QCOMProgram(HCQProgram):
     reg_desc_off = _read_lib(0x34)
     self.fregs, self.hregs = _read_lib(reg_desc_off + 0x14), _read_lib(reg_desc_off + 0x18)
 
-  def __del__(self):
-    if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferSpec(cpu_access=True, nolru=True))
-
 class QCOMTextureInfo:
   def __init__(self, pitch:int, real_stride:int, desc:list[int], ibo:list[int]):
     self.pitch, self.real_stride, self.desc, self.ibo = pitch, real_stride, desc, ibo
@@ -285,7 +278,7 @@ class QCOMAllocator(HCQAllocatorBase):
     pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
     size = pitch * imgh
 
-    buf = HCQBuffer(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)
+    buf = HCQBuffer(options.external_ptr, size, owner=self.dev) if options.external_ptr else self.dev._gpu_alloc(size)
 
     if options.image is not None:
       tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT
@@ -320,16 +313,12 @@
     self.dev._gpu_free(opaque)
 
 class QCOMDevice(HCQCompiled):
-  signals_page: Any = None
-  signals_pool: list[int] = []
   gpu_id: int = 0
   dummy_addr: int = 0
 
   def __init__(self, device:str=""):
-    self.fd = HWInterface('/dev/kgsl-3d0', os.O_RDWR)
+    self.fd = FileIOInterface('/dev/kgsl-3d0', os.O_RDWR)
     QCOMDevice.dummy_addr = cast(int, self._gpu_alloc(0x1000).va_addr)
-    QCOMDevice.signals_page = self._gpu_alloc(16 * 65536, uncached=True)
-    QCOMDevice.signals_pool = [self.signals_page.va_addr + off for off in range(0, self.signals_page.size, 16)]
 
     flags = kgsl.KGSL_CONTEXT_PREAMBLE | kgsl.KGSL_CONTEXT_PWR_CONSTRAINT | kgsl.KGSL_CONTEXT_NO_FAULT_TOLERANCE | kgsl.KGSL_CONTEXT_NO_GMEM_ALLOC \
             | kgsl.KGSL_CONTEXT_PRIORITY(8) | kgsl.KGSL_CONTEXT_PREEMPT_STYLE(kgsl.KGSL_CONTEXT_PREEMPT_STYLE_FINEGRAIN)
@@ -363,11 +352,11 @@
     va_addr = self.fd.mmap(0, bosz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, alloc.id * 0x1000)
 
     if fill_zeroes: ctypes.memset(va_addr, 0, size)
-    return HCQBuffer(va_addr=va_addr, size=size, meta=alloc)
+    return HCQBuffer(va_addr=va_addr, size=size, meta=alloc, view=MMIOInterface(va_addr, size, fmt='B'), owner=self)
 
   def _gpu_free(self, mem:HCQBuffer):
     kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.meta.id)
-    HWInterface.munmap(mem.va_addr, mem.meta.mmapsize)
+    FileIOInterface.munmap(mem.va_addr, mem.meta.mmapsize)
 
   def _ensure_stack_size(self, sz):
     if not hasattr(self, '_stack'): self._stack = self._gpu_alloc(sz)