tinygrad-0.10.2-py3-none-any.whl → tinygrad-0.11.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
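
The headline change in 0.11.0 is structural: the monolithic tinygrad/ops.py (-1003 lines) is replaced by the new tinygrad/uop/ package, scheduling moves from tinygrad/engine/schedule.py into tinygrad/schedule/, and kernel optimization moves into tinygrad/codegen/opt/. The hunks reproduced below cover the four graph runners (runtime/graph/cuda.py, hcq.py, metal.py, and the new remote.py); each carries the same one-line consequence of the ops move. A minimal before/after of that import fix:

    # 0.10.2
    from tinygrad.ops import UOp, Variable
    # 0.11.0
    from tinygrad.uop.ops import UOp, Variable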
tinygrad/runtime/graph/cuda.py

@@ -4,7 +4,7 @@ import tinygrad.runtime.autogen.cuda as cuda
  from tinygrad.helpers import init_c_var, dedup
  from tinygrad.device import Buffer, Device
  from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
- from tinygrad.ops import Variable
+ from tinygrad.uop.ops import Variable
  from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
  from tinygrad.engine.jit import MultiGraphRunner, GraphException

@@ -28,8 +28,8 @@ class CUDAGraph(MultiGraphRunner):
  deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
  c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None

- c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
- kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
+ c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x, ji.fixedvars.get(x)) for x in ji.prg.p.vars])
+ kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
  check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

  if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
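
Besides the mechanical rename to CUDA_KERNEL_NODE_PARAMS_v1, the hunk above introduces fixedvars: per-jit-item variable values baked in at capture time, with call-time var_vals taking precedence. A minimal sketch of the lookup order (standalone illustration, not tinygrad API):

    def resolve_var_values(kernel_vars, var_vals: dict, fixedvars: dict) -> list:
        # mirrors var_vals.get(x, ji.fixedvars.get(x)) in the encode_args call above
        return [var_vals.get(x, fixedvars.get(x)) for x in kernel_vars]

    assert resolve_var_values(["n", "m"], {"n": 4}, {"n": 1, "m": 8}) == [4, 8]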
tinygrad/runtime/graph/hcq.py

@@ -1,11 +1,11 @@
  import collections, time
  from typing import Any, cast
- from tinygrad.helpers import round_up, PROFILE
- from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
+ from tinygrad.helpers import round_up, PROFILE, merge_dicts, getenv, dedup
+ from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
  from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
  from tinygrad.dtype import dtypes
- from tinygrad.ops import UOp, Variable
- from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+ from tinygrad.uop.ops import UOp, Variable
+ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, BufferCopy
  from tinygrad.engine.jit import MultiGraphRunner

  class HCQGraph(MultiGraphRunner):

@@ -13,6 +13,9 @@ class HCQGraph(MultiGraphRunner):
  super().__init__(jit_cache, input_rawbuffers, var_vals)
  self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

+ # CPU Device is always last
+ self.devices = sorted(self.devices, key=lambda x: 1 if x._is_cpu() else 0)
+
  # Replace input buffers with variables.
  self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
  self.input_replace_to_var: dict[tuple[int, int], Variable] = {}

@@ -26,16 +29,17 @@
  for ji in jit_cache:
  if not isinstance(ji.prg, CompiledRunner): continue
  kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
- self.kernargs_bufs: dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferSpec(cpu_access=True)) for dev,sz in kernargs_size.items()}
+ self.kernargs_bufs: dict[Compiled, HCQBuffer] = {d:d.allocator._alloc(max(sz, 1), BufferSpec(cpu_access=True)) for d,sz in kernargs_size.items()}

  # Fill initial arguments.
  self.ji_args: dict[int, HCQArgsState] = {}

- kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size, base=cast(int, buf.va_addr)) for dev,buf in self.kernargs_bufs.items()}
+ kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size) for dev,buf in self.kernargs_bufs.items()}
  for j,ji in enumerate(jit_cache):
  if not isinstance(ji.prg, CompiledRunner): continue

- self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
+ argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
+ self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, argsbuf)

  # Schedule Dependencies.
  # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any

@@ -47,14 +51,15 @@
  self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
  self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation

- self.signals: dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
+ self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices if not dev._is_cpu()},
+ **{"KICK": self.devices[0].new_signal(value=0)}, **{dev: self.devices[0].new_signal(value=0) for dev in self.devices if dev._is_cpu()}}
  self.kickoff_value: int = 0
  self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)

  # When profiling allocate 2 signals for each jit item to measure speed. The jth jit item have signals at 2*j and 2*j+1.
  # TODO: This logic might allocate a few extra signals...
- self.prof_signals: list[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
- self.prog_graph_deps: list[list[int]] = []
+ self.prof_signals: list[HCQSignal] = []
+ self.prof_graph_deps: list[list[int]] = []
  self.prof_graph_entries: list[ProfileGraphEntry] = []

  last_j: dict[HWQueue, int|None] = collections.defaultdict(lambda: None)

@@ -63,8 +68,18 @@
  for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)

+ self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
+ self.fixedvars: dict[HCQCompiled, dict[Variable, int]] = {}
+
  for j,ji in enumerate(jit_cache):
- enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
+ if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
+ else:
+ # For copy ops prioritize enqeueuing on the dest device, so reverse the buffers.
+ for b in cast(list[Buffer], ji.bufs[::-1]):
+ if (enqueue_dev:=cast(HCQCompiled, Device[b.device])).hw_copy_queue_t is not None: break
+
+ # set any fixedvars on the device
+ self.fixedvars[enqueue_dev] = merge_dicts([self.fixedvars.get(enqueue_dev, {}), ji.fixedvars])

  if is_exec_prg:
  enqueue_queue = self.comp_queues[enqueue_dev]

@@ -72,7 +87,7 @@
  assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
  enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())

- out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
+ out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))

  # Get dependencies based on input and output buffers.
  rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore

@@ -86,9 +101,9 @@
  if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
  opt_deps.append((self.signals[dep_queue], dep_val))
  queue_access[enqueue_queue][dep_queue] = dep_val
+ dev_access[enqueue_queue].update(dev_access[dep_queue])

  # Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
- for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
  sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
  dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)

@@ -112,28 +127,31 @@
  prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore

  self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
- self.prog_graph_deps.append([d - 1 for _, d in rdeps])
+ self.prof_graph_deps.append([d - 1 for _, d in rdeps])

  last_j[enqueue_queue] = j

  # Check which signals are used in the profile graph.
- self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))]
+ self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(jit_cache) * 2)]

  # Build hardware queues.
  self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}

  # Create variable timeline signals for each device.
- timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
- self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
- self.virt_timeline_signals = {dev: dev.signal_t(base_addr=timeline_sigaddrs[dev], timeline_for_device=dev) for dev in self.devices}
+ timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{self.dev_name(dev)}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
+ self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{self.dev_name(dev)}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
+ self.virt_timeline_signals = {dev: dev.signal_t(HCQBuffer(timeline_sigaddrs[dev], 16), owner=dev, is_timeline=True) for dev in self.devices}

  for dev in self.devices:
  self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
- .wait(self.signals['CPU'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
+ .wait(self.signals['KICK'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)

  for j,ji in enumerate(jit_cache):
  enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]

+ # Lazy allocate signals
+ if PROFILE: self.prof_signals += [enqueue_dev.new_signal(value=0) for _ in range(2)]
+
  for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)

  # Encode waits and start profile timestamp (if needed).

@@ -142,10 +160,11 @@
  # Encode main commands based on ji type.
  if isinstance(ji.prg, CompiledRunner):
  enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
- elif isinstance(ji.prg, BufferXfer):
+ elif isinstance(ji.prg, (BufferXfer, BufferCopy)):
  dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
- cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
-
+ for bufid, src in enumerate(cast(list[Buffer], ji.bufs)):
+ if (inprep_idx:=self.input_replace.get((j, bufid))) is not None: self.input_replace_map[enqueue_dev].add(inprep_idx)
+ else: cast(HCQAllocator, enqueue_dev.allocator).map(self.hcq_bufs[j][bufid])
  enqueue_queue.copy(self.hcq_bufs[j][0].va_addr, self.hcq_bufs[j][1].va_addr, dest.nbytes)
  self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))

@@ -169,23 +188,25 @@
  self.kickoff_value += 1
  for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
  for sig in self.queue_signals_to_reset: sig.value = 0
- self.signals['CPU'].value = self.kickoff_value
+ self.signals['KICK'].value = self.kickoff_value
+
+ for dev in self.devices:
+ for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf)

  if PROFILE and self.kickoff_value > 1: self.collect_timestamps()

  hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
  **{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
- **{sig.base_addr: dev.timeline_signal.base_addr for dev, sig in self.virt_timeline_signals.items()}}
+ **{sig.base_buf.va_addr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}

  # Update rawbuffers
  for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr

  for dev in self.devices:
- self.comp_queues[dev].submit(dev, hcq_var_vals)
- if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals)
+ self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
+ if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals_local)

- self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
- dev.timeline_value += 1
+ self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())

  if wait:
  st = time.perf_counter()

@@ -195,7 +216,9 @@
  def collect_timestamps(self):
  # NOTE: Append to any device is fine...
- self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])]
+ self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prof_graph_deps, [s.timestamp for s in self.prof_signals])]
+
+ def dev_name(self, dev) -> str: return dev.device.replace(":", "_")

  def __del__(self):
  for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])

@@ -203,3 +226,19 @@
  if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()

  for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
+
+ @staticmethod
+ def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
+ # Check if all devices are HCQ
+ all_devs = cast(list[HCQCompiled], dedup(devs + [Device[b.device] for b in ei.bufs if b]))
+ if not all(issubclass(type(d), HCQCompiled) for d in all_devs): return False
+
+ # If all of devices are mapped into CPU address space, can use CPU inside the peer group.
+ cpu_support = all(isinstance(d.timeline_signal.base_buf.view, MMIOInterface) for d in all_devs)
+
+ # Check if all devices are within the same peer group. If CPU is supported, don't count it as a separate peer group.
+ if len(set(d.peer_group for d in all_devs if cpu_support and not d._is_cpu())) > 1: return False
+
+ # MOCKGPU is not supported, since it can't execute commands in parallel
+ copy = (isinstance(ei.prg, BufferCopy) and cast(HCQCompiled, devs[0]).hw_copy_queue_t is not None) and not getenv("MOCKGPU")
+ return isinstance(ei.prg, (CompiledRunner, BufferXfer)) or copy
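
Two details worth noting in the kernargs hunk near the top of this file: the arena is now allocated with max(sz, 1) so a device with no kernels still gets a valid buffer, and BumpAllocator hands out plain offsets applied via HCQBuffer.offset(...) instead of absolute va_addr-based pointers, decoupling the argument layout from the arena's address. A minimal offset-based bump allocator for illustration (tinygrad's real one lives in tinygrad/runtime/support/hcq.py and differs in detail):

    class OffsetBumpAllocator:
        """Hands out aligned offsets into a fixed-size arena; nothing is ever freed."""
        def __init__(self, size: int):
            self.size, self.ptr = size, 0

        def alloc(self, size: int, alignment: int = 1) -> int:
            self.ptr = (self.ptr + alignment - 1) // alignment * alignment  # round up to alignment
            if self.ptr + size > self.size: raise MemoryError("kernargs arena exhausted")
            offset, self.ptr = self.ptr, self.ptr + size
            return offset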
tinygrad/runtime/graph/metal.py

@@ -1,11 +1,11 @@
  from typing import Any, cast
- import ctypes
+ import ctypes, re, decimal
  from tinygrad.dtype import dtypes
- from tinygrad.helpers import dedup, getenv
- from tinygrad.device import Buffer
+ from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
+ from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
  from tinygrad.engine.realize import ExecItem, CompiledRunner
  from tinygrad.engine.jit import GraphRunner, GraphException
- from tinygrad.ops import Variable
+ from tinygrad.uop.ops import Variable
  from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
  MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str

@@ -32,11 +32,13 @@ class MetalGraph(GraphRunner):
  icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache)
  if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
  icb_label = bytes(msg("UTF8String", ctypes.c_char_p)(msg("description", objc_instance)(self.icb))).decode()
- self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3
+ self.needs_icb_fix = int((m := re.search(r'AGXG(\d+)XFamily', icb_label)) is None or int(m.group(1)) < 15) # not required on M3+

- if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
- all_resources = [self.int_buf.buf] if len(self.vars) else []
- all_pipelines = []
+ self.fixedvars = merge_dicts([ji.fixedvars for ji in jit_cache])
+ self.varlist = self.vars + list(self.fixedvars.keys())
+ if len(self.varlist): self.int_buf = self.dev.allocator.alloc(len(self.varlist)*dtypes.int32.itemsize)
+
+ all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.varlist) else []
  for j,ji in enumerate(jit_cache):
  prg: CompiledRunner = cast(CompiledRunner, ji.prg)
  icb_command = msg("indirectComputeCommandAtIndex:", objc_instance)(self.icb, j)
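
The needs_icb_fix change above swaps an exact-string match for a parse of the AGX family number out of the ICB class name, so family 15 (M3) and anything newer skips the workaround without further string updates. The predicate in isolation:

    import re

    def needs_icb_fix(icb_label: str) -> bool:
        # AGXG<N>XFamily... with N >= 15 does not need the fix
        m = re.search(r'AGXG(\d+)XFamily', icb_label)
        return m is None or int(m.group(1)) < 15

    assert needs_icb_fix("AGXG14XFamilyIndirectCommandBuffer")        # pre-M3: fix needed
    assert not needs_icb_fix("AGXG15XFamilyIndirectCommandBuffer")    # M3+: no fix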
@@ -46,7 +48,7 @@
  if b is not None and b not in input_rawbuffers:
  msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
  all_resources.append(b._buf.buf)
- for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
+ for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.varlist.index(v)*4, len(ji.bufs)+i)

  global_size, local_size = prg.p.launch_dims(var_vals)
  msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))

@@ -55,14 +57,16 @@
  self.all_resources = dedup(all_resources)
  self.all_pipelines = dedup(all_pipelines)
  self.command_buffer: Any = None
- if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
+ if len(self.varlist): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
+ for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
  self.range = to_struct(0, len(jit_cache))

  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
-
  if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
- all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])
+ # NOTE: old command buffer may not be inflight anymore
+ if self.command_buffer is not None and PROFILE: self.collect_timestamps()

+ all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
  for (j,i),input_idx in self.input_replace.items():
  computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
  msg("setKernelBuffer:offset:atIndex:")(computeCommand, input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)

@@ -70,7 +74,7 @@
  for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
  computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
  msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(computeCommand, to_struct(*global_dims), to_struct(*local_dims))
- for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
+ for var in self.vars: self.int_buf_view[self.varlist.index(var)] = var_vals[var]

  command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
  encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)

@@ -98,3 +102,15 @@
  wait_check(command_buffer)
  return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
  return None
+
+ def collect_timestamps(self):
+ # create a graph event and evenly space each program
+ st, en = decimal.Decimal(cmdbuf_st_time(self.command_buffer)) * 1000000, decimal.Decimal(cmdbuf_en_time(self.command_buffer)) * 1000000
+ ents = [ProfileGraphEntry(self.device, cast(CompiledRunner, ji.prg)._prg.name, i, i+1, is_copy=False) for i,ji in enumerate(self.jit_cache)]
+ step = (en-st)/len(ents)
+ self.dev.profile_events += [ProfileGraphEvent(ents, [], [st+step*i for i in range(len(ents)+1)])]
+
+ def __del__(self):
+ if PROFILE and self.command_buffer is not None:
+ wait_check(self.command_buffer)
+ self.collect_timestamps()
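
MetalGraph also gains PROFILE support. A Metal command buffer reports a single GPU start/end time for the whole batch, so collect_timestamps fabricates per-program intervals by spacing the jit items evenly across that window (the *1000000 presumably converts the second-based cmdbuf times to microseconds). The spacing arithmetic in isolation:

    from decimal import Decimal

    def evenly_spaced(st_us: Decimal, en_us: Decimal, n_programs: int) -> list[Decimal]:
        # n_programs+1 fence posts: program i is charged the span ts[i]..ts[i+1]
        step = (en_us - st_us) / n_programs
        return [st_us + step * i for i in range(n_programs + 1)]

    assert evenly_spaced(Decimal(0), Decimal(30), 3) == [Decimal(0), Decimal(10), Decimal(20), Decimal(30)]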
tinygrad/runtime/graph/remote.py (new file)

@@ -0,0 +1,114 @@
+ import time, itertools
+ from tinygrad.uop.ops import Variable
+ from tinygrad.engine.jit import MultiGraphRunner
+ from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
+ from tinygrad.device import Device, Compiled, Buffer
+ from tinygrad.runtime.ops_remote import RemoteDevice, RemoteConnection, RemoteRequest, GraphComputeItem, Transfer, GraphAlloc, GraphFree, GraphExec
+ from tinygrad.runtime.ops_remote import BatchTransfer, Event, Wait
+ from tinygrad.helpers import unwrap, flatten, dedup
+ from enum import Enum, auto
+ from dataclasses import replace
+ from collections import defaultdict
+ from typing import cast
+
+ class StagingType(Enum): NONE = auto(); GRAPH = auto(); TRANSFER = auto() # noqa: E702
+
+ def rd(dev:Compiled) -> RemoteDevice: return cast(RemoteDevice, dev)
+ def dev_key(dev:RemoteDevice): return dev.conn if dev.properties.graph_supports_multi else dev
+ def map_rawbuf(rawbuf:Buffer): return (cast(RemoteDevice, Device[rawbuf.device]).session, rawbuf._buf)
+
+ class RemoteGraph(MultiGraphRunner):
+ def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[Variable, int]):
+ super().__init__(jit_cache, rawbufs, var_vals)
+ devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
+ c2d = {device.conn: device for device in devices}
+ self.handle_indexes = {map_rawbuf(rawbufs[i]): i for i in sorted(dedup(self.input_replace.values()))}
+
+ self.template: list[RemoteRequest] = []
+
+ stagings: dict[RemoteDevice|RemoteConnection, list[GraphComputeItem|Transfer]] = defaultdict(list)
+ clobbered_buffers: set[Buffer] = set()
+ cur_staging_type: StagingType = StagingType.NONE
+
+ def _flush(new_staging_type:StagingType, force_break:bool=False):
+ nonlocal cur_staging_type
+ if cur_staging_type == new_staging_type and not force_break: return
+ # Pre-sync
+ if cur_staging_type == StagingType.TRANSFER:
+ for sdev,ddev in itertools.permutations(c2d.values(), 2):
+ self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
+ self.template.append(Wait(event, session=ddev.session))
+ # Flush
+ for dev in devices:
+ dk = dev_key(dev)
+ staging = stagings[dk]
+ if not staging: continue
+ match cur_staging_type:
+ case StagingType.GRAPH:
+ bufs = tuple(map_rawbuf(rawbufs[i]) for i in sorted(dedup(self.input_replace.values())) if dev_key(rd(Device[rawbufs[i].device])) == dk)
+ dev.q(GraphAlloc(graph_num:=next(dev.graph_num), tuple(staging), tuple(bufs), var_vals))
+ self.template.append(GraphExec(graph_num, bufs, var_vals, wait=False, session=dev.session))
+ case StagingType.TRANSFER:
+ st = cast(list[Transfer], staging)
+ for host in dedup(t.dsession.host for t in st):
+ sbuffer_nums = [(unwrap(t.session), t.buffer_num) for t in st if t.dsession.host == host]
+ dbuffer_nums = [(t.dsession, t.dbuffer_num) for t in st if t.dsession.host == host]
+ self.template.append(BatchTransfer(sbuffer_nums, dbuffer_nums, session=dev.session))
+ staging.clear()
+ # Post-sync
+ if cur_staging_type == StagingType.TRANSFER:
+ for sdev,ddev in itertools.permutations(c2d.values(), 2):
+ self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
+ self.template.append(Wait(event, session=ddev.session))
+ cur_staging_type = new_staging_type
+ clobbered_buffers.clear()
+
+ for ji in jit_cache:
+ match ji.prg:
+ case CompiledRunner():
+ _flush(StagingType.GRAPH)
+ gi = GraphComputeItem(ji.prg.dev.session, ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs),
+ tuple(ji.prg.p.vars), ji.fixedvars, tuple(ji.prg.p.ins), tuple(ji.prg.p.outs),
+ tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
+ tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
+ stagings[dev_key(ji.prg.dev)].append(gi)
+ case BufferXfer():
+ dest, src = ji.bufs[0:2]
+ dest_dev, src_dev = cast(RemoteDevice, Device[unwrap(dest).device]), cast(RemoteDevice, Device[unwrap(src).device])
+ assert dest is not None and src is not None, ji
+ ti = Transfer(session=src_dev.session, buffer_num=src._buf, dsession=dest_dev.session, dbuffer_num=dest._buf)
+ if dev_key(dest_dev) == dev_key(src_dev):
+ _flush(StagingType.GRAPH)
+ stagings[dev_key(src_dev)].append(ti)
+ elif dest_dev.conn == src_dev.conn:
+ _flush(StagingType.NONE)
+ self.template.append(ti)
+ else:
+ _flush(StagingType.TRANSFER, force_break=src in clobbered_buffers)
+ clobbered_buffers.add(dest)
+ stagings[dev_key(src_dev)].append(ti)
+ case _: raise NotImplementedError(ji.prg)
+ _flush(StagingType.NONE)
+ def __del__(self):
+ for req in self.template:
+ match req:
+ case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
+ def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
+ if wait: st = time.perf_counter()
+ rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
+ for req in self.template:
+ match req:
+ case GraphExec():
+ req = replace(req, bufs=tuple(rmap[buf] for buf in req.bufs), var_vals=var_vals, wait=wait)
+ case Transfer():
+ if (req.session, req.buffer_num) in rmap: req = replace(req, buffer_num=rmap[(req.session, req.buffer_num)][1])
+ if (req.dsession, req.dbuffer_num) in rmap: req = replace(req, dbuffer_num=rmap[(req.dsession, req.dbuffer_num)][1])
+ case BatchTransfer():
+ req = replace(req, sbuffer_nums=[rmap.get(b, b) for b in req.sbuffer_nums], dbuffer_nums=[rmap.get(b, b) for b in req.dbuffer_nums])
+ case Event()|Wait():
+ pass # event number can be reused
+ case _: raise NotImplementedError(req)
+ RemoteConnection(unwrap(req.session).host).q(req)
+ if wait:
+ RemoteConnection(unwrap(req.session).host).batch_submit()
+ return time.perf_counter() - st
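
The heart of RemoteGraph is the _flush state machine: consecutive jit items of the same staging type (GRAPH for compute, TRANSFER for cross-host copies) accumulate into a single batch, and the expensive Event/Wait synchronization between connections is emitted only at type boundaries (or when force_break detects a clobbered source buffer). Stripped of the remote protocol, the batching pattern reduces to run-length grouping (hypothetical names):

    from enum import Enum, auto

    class Kind(Enum):
        GRAPH = auto(); TRANSFER = auto()  # noqa: E702

    def batch_runs(items):
        # group maximal runs of equal Kind, the way _flush only fires on a type change
        out: list[tuple[Kind, list]] = []
        for kind, payload in items:
            if out and out[-1][0] is kind: out[-1][1].append(payload)
            else: out.append((kind, [payload]))
        return out

    runs = batch_runs([(Kind.GRAPH, "k1"), (Kind.GRAPH, "k2"), (Kind.TRANSFER, "t1")])
    assert [k for k, _ in runs] == [Kind.GRAPH, Kind.TRANSFER]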