tinygrad-0.9.1-py3-none-any.whl → tinygrad-0.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/graph/clang.py
@@ -1,10 +1,10 @@
 from typing import List, Dict, cast
 import ctypes
-from tinygrad.helpers import dedup, cpu_time_execution, GraphException, DEBUG
-from tinygrad.engine.jit import GraphRunner
+from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
+from tinygrad.engine.jit import GraphRunner, GraphException
 from tinygrad.device import Buffer, Device
 from tinygrad.engine.realize import ExecItem, CompiledRunner
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.runtime.ops_clang import ClangProgram
 from tinygrad.renderer.cstyle import ClangRenderer
 render_dtype = ClangRenderer().render_dtype
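The same two moves recur in every graph backend below: `GraphException` now lives in `tinygrad.engine.jit` instead of `tinygrad.helpers`, and `Variable` moves to `tinygrad.ops` because `tinygrad/shape/symbolic.py` was deleted (file 70 in the list above). A minimal compatibility sketch for downstream code that imports these names; the try/except fallback is illustrative, not part of the diff:

```python
# Sketch: tolerate both the 0.9.1 and 0.10.0 import locations.
try:
    from tinygrad.engine.jit import GraphRunner, GraphException  # 0.10.0
    from tinygrad.ops import Variable
except ImportError:
    from tinygrad.engine.jit import GraphRunner                  # 0.9.1
    from tinygrad.helpers import GraphException
    from tinygrad.shape.symbolic import Variable                 # removed in 0.10.0
```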
tinygrad/runtime/graph/cuda.py
@@ -1,12 +1,12 @@
 import ctypes
 from typing import Any, Optional, Tuple, Dict, List, cast
 import tinygrad.runtime.autogen.cuda as cuda
-from tinygrad.helpers import init_c_var, GraphException, dedup
+from tinygrad.helpers import init_c_var, dedup
 from tinygrad.device import Buffer, Device
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
-from tinygrad.engine.jit import MultiGraphRunner
+from tinygrad.engine.jit import MultiGraphRunner, GraphException
 
 class CUDAGraph(MultiGraphRunner):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
@@ -25,21 +25,20 @@ class CUDAGraph(MultiGraphRunner):
         global_size, local_size = ji.prg.p.launch_dims(var_vals)
 
         new_node = cuda.CUgraphNode()
-        deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None],
-                                      [x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
+        deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
         c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
 
         c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
         kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
         check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
 
-        if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
+        if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
           self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
         src_dev = cast(CUDADevice, Device[src.device])
         node_from = cuda.CUgraphNode()
-        deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
+        deps = self._access_resources(rawbufs=[dest.base, src.base], write=[0], new_dependency=node_from)
         c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
         cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
                                           dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
@@ -58,13 +57,13 @@ class CUDAGraph(MultiGraphRunner):
         elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
 
     # Update var_vals in the c_args struct.
-    for j in self.jc_idx_with_updatable_var_vals:
-      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
-        setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
+    for j, i, v in self.updated_vars(var_vals): setattr(self.updatable_nodes[j][2], f'v{i}', v)
 
     # Update launch dims in the kern_params struct.
-    for j in self.jc_idx_with_updatable_launch_dims:
-      self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
+    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
+      prg = cast(CompiledRunner, self.jit_cache[j].prg)
+      node, global_size, local_size = self.updatable_nodes[j][1], global_dims or prg.p.global_size, local_dims or prg.p.local_size
+      node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size # type: ignore[misc]
 
     # Update graph nodes with the updated structs.
     for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
@@ -76,6 +75,3 @@ class CUDAGraph(MultiGraphRunner):
   def __del__(self):
     if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
     if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
-
-  def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
-    node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
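Both update loops now consume `GraphRunner.updated_vars` and `GraphRunner.updated_launch_dims`, new helpers in `tinygrad/engine/jit.py` (changed in this diff but not shown here). Only their yielded shapes are visible from the call sites; a hedged sketch of that contract, with stand-in data structures in place of the real `GraphRunner` state:

```python
from typing import Dict, Iterator, List, Tuple

def updated_vars(var_vals_replace: Dict[int, List[int]], jit_vars: List[str],
                 var_vals: Dict[str, int]) -> Iterator[Tuple[int, int, int]]:
    # Yields (jit-cache index j, var slot i, new value v) for every symbolic
    # var that must be refreshed before the captured graph is replayed.
    for j, var_idxs in var_vals_replace.items():
        for i, v in enumerate(var_idxs):
            yield j, i, var_vals[jit_vars[v]]

# Call-site shape, as in the CUDA backend above:
#   for j, i, v in self.updated_vars(var_vals):
#       setattr(self.updatable_nodes[j][2], f'v{i}', v)
```

`updated_launch_dims(var_vals)` follows the same pattern, yielding `(j, global_dims, local_dims)`; a `None` entry means that dimension set is static, which is why the call site falls back to `prg.p.global_size`/`prg.p.local_size`.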
tinygrad/runtime/graph/hcq.py
@@ -1,187 +1,200 @@
-import collections, array, time
+import collections, time
 from typing import List, Any, Dict, cast, Optional, Tuple, Set
-from tinygrad.helpers import round_up, to_mv, PROFILE
+from tinygrad.helpers import round_up, PROFILE, memsize_to_str
+from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, HCQArgsState
 from tinygrad.device import Buffer, BufferOptions, Compiled, Device
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner
 
 class HCQGraph(MultiGraphRunner):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     super().__init__(jit_cache, input_rawbuffers, var_vals)
-    self.devices = list(set(cast(Any, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
+    self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
 
     # Allocate kernel args.
     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
     for ji in self.jit_cache:
       if not isinstance(ji.prg, CompiledRunner): continue
       kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
-    self.kernargs_bufs: Dict[Compiled, Any] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
-    kernargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
+    self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
 
     # Fill initial arguments.
-    self.kargs_addrs: Dict[int, int] = {}
-    self.ji_args_bufs: Dict[int, memoryview] = {}
-    self.ji_args_vars: Dict[int, memoryview] = {}
+    self.ji_args: Dict[int, HCQArgsState] = {}
+
+    kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
     for j,ji in enumerate(self.jit_cache):
       if not isinstance(ji.prg, CompiledRunner): continue
-      self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
-      kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
-
-      self.ji_args_bufs[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset, len(ji.bufs) * 8).cast('Q')
-      self.ji_args_vars[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset + len(ji.bufs) * 8, len(ji.prg.p.vars) * 4).cast('I')
-      for i in range(len(ji.bufs)): self.ji_args_bufs[j][i] = cast(Buffer, ji.bufs[i])._buf.va_addr
-      for i in range(len(ji.prg.p.vars)): self.ji_args_vars[j][i] = var_vals[ji.prg.p.vars[i]]
-
-      # NV needs constbuffer to be set
-      if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
+      kargs_ptrs[ji.prg.device] = (kargs_ptr:=kargs_ptrs[ji.prg.device]) + round_up(ji.prg.clprg.kernargs_alloc_size, 16)
+      self.ji_args[j] = ji.prg.clprg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)
 
     # Schedule Dependencies.
     # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
     # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
     # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device's
     # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
-    self.comp_queues: Dict[Compiled, Any] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
-    self.copy_queues: Dict[Compiled, Any] = {dev: dev.hw_copy_queue_t() for dev in self.devices}
+    self.ji_schedule: Dict[int, Tuple[HCQCompiled, HWCommandQueue, List, List, HCQSignal, Optional[int]]] = {}
+
+    self.comp_queues: Dict[HCQCompiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
+    self.copy_queues: Dict[HCQCompiled, HWCopyQueue] = {} # lazy allocation
 
-    self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, sigval, prof_info)]
-    self.signals: Dict[Any, Any] = {q: self.devices[0]._get_signal(value=0) for q in list(self.comp_queues.values())+list(self.copy_queues.values())}
-    self.dev_kickoff_signal = {dev: self.devices[0]._get_signal(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal]
-    self.kickoff_value = 0
+    self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
+    self.kickoff_value: int = 0
 
-    self.save_devs: Dict[Any, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
-    for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)
+    self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
+    self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
 
-    self.graph_timeline = {dev: 0 for dev in self.devices} # Dict[dev, last graph sigval]
-    self.last_ji: Dict[Any, Any] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
+    last_j: Dict[HWCommandQueue, Optional[int]] = collections.defaultdict(lambda: None)
+    queue_access: Dict[HWCommandQueue, Dict[HWCommandQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
+    dev_access: Dict[HWCommandQueue, Set[HCQCompiled]] = collections.defaultdict(set)
+
+    for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
 
     for j,ji in enumerate(self.jit_cache):
-      enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
-      enqueue_queue = self.comp_queues[enqueue_dev] if isinstance(ji.prg, CompiledRunner) else self.copy_queues[enqueue_dev]
-      out_signal = self.signals[enqueue_queue]
-      writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
-      deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)
+      enqueue_dev = ji.prg.device if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
+      enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
+      out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
 
-      if isinstance(ji.prg, CompiledRunner):
-        # Update signal on compute kernel to depend on the previous kernel.
-        if (last_j:=self.last_ji[enqueue_queue]) is not None: deps = [x for x in deps if id(x[0]) != id(out_signal)] + [(out_signal, last_j + 1)]
+      # Get dependencies based on input and output buffers.
+      rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore
+
+      # Update dependencies to include previous kernel in queue. This is required for timeline signals.
+      opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])
+
+      # Optimize dependencies by removing redundant ones. Remove waiting for the value of the queue which is known to be already
+      # synced with the current queue.
+      for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
+        if (qa:=queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
+          opt_deps.append((self.signals[dep_queue], dep_val))
+          queue_access[enqueue_queue][dep_queue] = dep_val
 
-        # Remove self-dependency for AMD or NV with only 1 same-queue dep, since NV chains 2+ execs in this case, eliminating dep need.
-        if (dname:=enqueue_dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(out_signal)):
-          deps = [x for x in deps if id(x[0]) != id(out_signal)]
-      elif isinstance(ji.prg, BufferXfer): deps = [x for x in deps if id(x[0]) != id(out_signal)]
+      # Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
+      for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
+      sync_signals = [(self.signals[d], self.kickoff_value) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
+      dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)
 
-      # Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
-      for sig, val in deps:
-        if id(sig) in [id(x) for x in self.signals.values()]:
-          self.signal_sched[val - 1] = self.signal_sched[val - 1][:1] + (val,) + self.signal_sched[val - 1][2:]
+      # Remove self-dependency for compute and copy queues.
+      # For compute, in case of NV, optimize when only 1 same-queue dependency exists, since NV chains 2+ executions in this case,
+      # eliminating dependency need.
+      dname = enqueue_dev.dname.split(":", 1)[0]
+      can_opt = dname in {"AMD", "QCOM"} or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal))
+      if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]
 
-      prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
-      prof_info = ([enqueue_dev._get_signal() for _ in range(2)] + [enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None
-      self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
-      self.last_ji[enqueue_queue] = j
+      # Enable necessary signals in the schedule by setting the signal value.
+      for sig, val in opt_deps: self.ji_schedule[val - 1] = self.ji_schedule[val - 1][:5] + (val,)
+      self.ji_schedule[j] = (enqueue_dev, enqueue_queue, sync_signals, opt_deps[::-1], out_signal, None if is_exec_prg else (j + 1))
+
+      # Collect profile information if profiling is enabled.
+      if PROFILE:
+        prof_ji_desc = ji.prg.clprg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
+
+        sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
+        if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
+
+        if is_exec_prg: prof_args = None
+        else: prof_args = {"Size": memsize_to_str(ji.bufs[0].nbytes), "GB/S": lambda dur, b=ji.bufs[0].nbytes: f"{b/1e3/dur:.2f}"} # type: ignore
+
+        self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps], prof_args))
+
+      last_j[enqueue_queue] = j
 
     # Build hardware queues.
-    self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
-    self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
-    self.kickoff_wait_cmds: Dict[Any, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
+    self.op_cmd_idx: Dict[int, Tuple[Any, int]] = {}
+    self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
+    self.kickoff_wait_cmds: Dict[HWCommandQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
 
     for dev in self.devices:
       self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
-                           .wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev], self.kickoff_value)
+                           .wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev], self.kickoff_value)
 
     for j,ji in enumerate(self.jit_cache):
-      deps, signal_value, prof_info = self.signal_sched[j]
-      enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore
+      enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
+
+      for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
+      for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
 
       # Encode waits and start profile timestamp (if needed).
-      for sig, val in deps:
-        enqueue_queue.wait(sig, val)
-        if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
-      if prof_info: enqueue_queue.timestamp(prof_info[0])
+      if PROFILE and self.prof_records[j][0][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][0][0]])
 
       # Encode main commands based on ji type.
       if isinstance(ji.prg, CompiledRunner):
-        enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
-                           signal=self.signals[enqueue_queue] if signal_value is not None else None, signal_value=signal_value)
-        self.exec_ptrs[j] = (enqueue_queue, len(enqueue_queue) - 1)
+        cast(HWComputeQueue, enqueue_queue).exec(ji.prg.clprg, self.ji_args[j], *ji.prg.p.launch_dims(var_vals))
       elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
-        Device[src.device]._gpu_map(dest._buf) #type: ignore
-        enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes).signal(self.signals[enqueue_queue], signal_value)
-        self.copy_to_devs[Device[dest.device]].add(Device[src.device])
+        cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
+        cast(HWCopyQueue, enqueue_queue).copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
+        self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
+      self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)
 
       # Encode finish profile timestamp (if needed).
-      if prof_info: enqueue_queue.timestamp(prof_info[1])
+      if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
+
+      if signal_val is not None: enqueue_queue.signal(signal, signal_val)
 
     for dev in self.devices:
       for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
-        if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue
-        self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][1])
+        if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
+
+      self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value).bind(dev)
+      if dev in self.copy_queues: self.copy_queues[dev].bind(dev)
 
-      self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
-      if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
-      if hasattr(self.copy_queues[dev], 'bind') and self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)
+    self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
+    self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
 
   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
     # Wait and restore signals
     self.kickoff_value += 1
-    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
-    for queue in self.comp_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
-    for queue in self.copy_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
-    self.devices[0]._set_signal(self.dev_kickoff_signal['CPU'], self.kickoff_value)
+    for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
+    for sig in self.queue_signals_to_reset: sig.value = 0
+    self.signals['CPU'].value = self.kickoff_value
 
-    if PROFILE and self.kickoff_value > 1:
-      for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
-        dev.raw_prof_records += [(dev._read_timestamp(st), dev._read_timestamp(en), desc, is_cp)]
+    if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
 
     # Update rawbuffers
-    for (j,i),input_idx in self.input_replace.items(): self.ji_args_bufs[j][i] = input_rawbuffers[input_idx]._buf.va_addr
+    for (j,i),input_idx in self.input_replace.items():
+      if j in self.ji_args: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf)
+      else: self.op_cmd_idx[j][0].update_copy(self.op_cmd_idx[j][1], **{('dest' if i == 0 else 'src'): input_rawbuffers[input_idx]._buf.va_addr})
 
     # Update var_vals
-    for j in self.jc_idx_with_updatable_var_vals:
-      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars): self.ji_args_vars[j][i] = var_vals[v]
+    for j, i, v in self.updated_vars(var_vals): self.ji_args[j].update_var(i, v)
 
-    for j in self.jc_idx_with_updatable_launch_dims:
-      queue, cmd_ptr = self.exec_ptrs[j]
-      queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
+    # Update launch dims
+    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
+      queue, cmd_ptr = self.op_cmd_idx[j]
+      queue.update_exec(cmd_ptr, global_dims, local_dims)
 
     for dev in self.devices:
-      self.comp_queues[dev].update_wait(1, dev.timeline_signal, dev.timeline_value - 1).update_wait(2, value=self.kickoff_value) \
-                           .update_signal(3, value=self.kickoff_value) \
-                           .update_signal(len(self.comp_queues[dev]) - 1, dev.timeline_signal, dev.timeline_value).submit(dev)
+      comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues.get(dev, None), dev.timeline_signal != self.last_timeline[dev][0]
+      comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
+                .update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
+                .update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
 
-      if self.last_ji[(cp_queue:=self.copy_queues[dev])] is not None:
-        for cmd_idx in self.kickoff_wait_cmds[cp_queue]: cp_queue.update_wait(cmd_idx, value=self.kickoff_value)
-        cp_queue.submit(dev)
+      if copy_queue is not None:
+        for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
+        copy_queue.submit(dev)
 
-      self.graph_timeline[dev] = dev.timeline_value
+      self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
       dev.timeline_value += 1
 
     if wait:
       st = time.perf_counter()
-      for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
+      for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
      return time.perf_counter() - st
    return None
 
-  def access_resources(self, queue, read, write, new_val):
-    deps = self._access_resources(read, write, (queue, new_val))
+  def collect_timestamps(self):
+    timestamps = [s.timestamp for s in self.prof_signals]
 
-    sync_signals = []
-    for dep_queue,_ in deps: self.save_devs[queue].update(self.save_devs[dep_queue])
-    for buf in read+write:
-      if buf.device not in self.save_devs[queue]:
-        self.save_devs[queue].add(buf.device)
-        sync_signals += [(self.dev_kickoff_signal[Device[buf.device]], self.kickoff_value)]
+    for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records:
+      dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)]
 
-    return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals
+      for x in deps:
+        (b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
+        dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
 
   def __del__(self):
-    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
+    for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
 
-    # Graph is destructed. No need to keep signals any more, so return them as part of profiling.
-    if PROFILE and self.kickoff_value > 1:
-      for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore
+    if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()
 
-    self.devices[0].signals_pool += list(self.dev_kickoff_signal.values()) + list(self.signals.values()) # type: ignore
-    for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))
+    for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferOptions(cpu_access=True))
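The core of the rewritten scheduler is the `queue_access` table: before emitting a wait, each queue records the highest signal value it has already synchronized with per source queue, and since a queue that has waited for value N of a monotonically increasing signal has implicitly waited for every value below it, deps are scanned highest-first and lower duplicates are dropped. A self-contained sketch of just that pruning rule, with plain strings standing in for the queue objects:

```python
import collections
from typing import Dict, List, Optional, Tuple

QueueAccess = Dict[str, Dict[str, Optional[int]]]

def prune_deps(enqueue_queue: str, deps: List[Tuple[str, int]],
               queue_access: QueueAccess) -> List[Tuple[str, int]]:
    # Keep a dependency only if this queue has not already synchronized with
    # an equal-or-higher value of the same source queue's signal.
    opt_deps = []
    for dep_queue, dep_val in sorted(deps, key=lambda x: x[1], reverse=True):
        if (qa := queue_access[enqueue_queue][dep_queue]) is None or qa < dep_val:
            opt_deps.append((dep_queue, dep_val))
            queue_access[enqueue_queue][dep_queue] = dep_val
    return opt_deps

queue_access: QueueAccess = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
# The wait on ("copy0", 7) subsumes the wait on ("copy0", 5):
assert prune_deps("comp0", [("copy0", 5), ("copy0", 7)], queue_access) == [("copy0", 7)]
```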
tinygrad/runtime/graph/metal.py
@@ -1,12 +1,20 @@
 from typing import List, Any, Dict, cast, Optional
-import Metal
+import ctypes
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import dedup, unwrap2, GraphException
+from tinygrad.helpers import dedup, getenv
 from tinygrad.device import Buffer
 from tinygrad.engine.realize import ExecItem, CompiledRunner
-from tinygrad.engine.jit import GraphRunner
-from tinygrad.shape.symbolic import Variable
-from tinygrad.runtime.ops_metal import wait_check
+from tinygrad.engine.jit import GraphRunner, GraphException
+from tinygrad.ops import Variable
+from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
+  MTLResourceOptions, elapsed_time, objc_id
+
+class MTLIndirectCommandType:
+  MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
+
+class MTLResourceUsage:
+  MTLResourceUsageRead = 0b01
+  MTLResourceUsageWrite = 0b10
 
 class MetalGraph(GraphRunner):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
@@ -14,62 +22,82 @@ class MetalGraph(GraphRunner):
     if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
 
     # create metal batch exec
-    icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
-    icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
-    icb_descriptor.setInheritBuffers_(False)
-    icb_descriptor.setInheritPipelineState_(False)
-    icb_descriptor.setMaxKernelBufferBindCount_(31)
-    self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
-                                                                                                  Metal.MTLResourceOptions(0))
-    if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
+    icb_descriptor = msg(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"), "new", restype=objc_instance)
+    msg(icb_descriptor, "setCommandTypes:", MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
+    msg(icb_descriptor, "setInheritBuffers:", False)
+    msg(icb_descriptor, "setInheritPipelineState:", False)
+    msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)
 
-    if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
-    all_resources = [self.int_buf] if len(self.vars) else []
+    self.icb = msg(self.device.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
+                   icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
+    if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
+    icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
+    self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3
 
+    if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
+    all_resources = [self.int_buf.buf] if len(self.vars) else []
+    all_pipelines = []
     for j,ji in enumerate(self.jit_cache):
       prg: CompiledRunner = cast(CompiledRunner, ji.prg)
-      descriptor = Metal.MTLComputePipelineDescriptor.new()
-      descriptor.setComputeFunction_(prg.clprg.fxn)
-      descriptor.setSupportIndirectCommandBuffers_(True)
-      icb_command = self.icb.indirectComputeCommandAtIndex_(j)
-      icb_command.setComputePipelineState_(unwrap2(
-        self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)))
+      icb_command = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_instance)
+      all_pipelines.append(prg.clprg.pipeline_state)
+      msg(icb_command, "setComputePipelineState:", prg.clprg.pipeline_state)
       for i,b in enumerate(ji.bufs):
-        if b is not None:
-          icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
-          all_resources.append(b._buf)
-      for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, self.vars.index(v)*4, len(ji.bufs)+i)
-      if j not in self.jc_idx_with_updatable_launch_dims:
-        global_size, local_size = prg.p.launch_dims(var_vals)
-        icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
-      icb_command.setBarrier()
+        if b is not None and b not in input_rawbuffers:
+          msg(icb_command, "setKernelBuffer:offset:atIndex:", b._buf.buf, b._buf.offset, i)
+          all_resources.append(b._buf.buf)
+      for i,v in enumerate(prg.p.vars): msg(icb_command, "setKernelBuffer:offset:atIndex:", self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
+
+      global_size, local_size = prg.p.launch_dims(var_vals)
+      msg(icb_command, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
+      msg(icb_command, "setBarrier")
 
     self.all_resources = dedup(all_resources)
+    self.all_pipelines = dedup(all_pipelines)
     self.command_buffer: Any = None
-    if len(self.vars): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
+    if len(self.vars): self.int_buf_view = self.device.allocator.as_buffer(self.int_buf).cast('i')
+    self.range = to_struct(0, len(self.jit_cache))
 
   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+
     if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
-    all_resources = dedup(self.all_resources + [x._buf for x in input_rawbuffers])
+    all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])
 
     for (j,i),input_idx in self.input_replace.items():
-      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
-    for j in self.jc_idx_with_updatable_launch_dims:
-      global_size, local_size = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
-      self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
-                                                                                                       Metal.MTLSize(*local_size))
+      computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_id)
+      msg(computeCommand, "setKernelBuffer:offset:atIndex:", input_rawbuffers[input_idx]._buf.buf,
+          input_rawbuffers[input_idx]._buf.offset, i)
+
+    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
+      prg = cast(CompiledRunner, self.jit_cache[j].prg)
+      global_size, local_size = global_dims or prg.p.global_size, local_dims or prg.p.local_size
+      computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j)
+      msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:",
+          to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
     for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
 
-    command_buffer = self.device.mtl_queue.commandBuffer()
-    encoder = command_buffer.computeCommandEncoder()
-    encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
-    encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0, len(self.jit_cache)))
-    encoder.endEncoding()
-    command_buffer.commit()
+    command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
+    encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
+    msg(encoder, "useResources:count:usage:", (objc_id * len(all_resources))(*all_resources), len(all_resources),
+        MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
+
+    # NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how
+    # this is a O(n) hack to get them used. what should work is:
+    #encoder.useResources_count_usage_(self.all_pipelines, len(self.all_pipelines), Metal.MTLResourceUsageRead)
+    # but it fails with "Invalid Resource (00000009:kIOGPUCommandBufferCallbackErrorInvalidResource)"
+    # to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
+    if getenv("FIX_METAL_ICB", self.needs_icb_fix):
+      for ps in self.all_pipelines:
+        msg(encoder, "setComputePipelineState:", ps)
+        msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(0,0,0), to_struct(0,0,0))
+
+    msg(encoder, "executeCommandsInBuffer:withRange:", self.icb, self.range)
+    msg(encoder, "endEncoding")
+    msg(command_buffer, "commit")
     self.command_buffer = command_buffer
 
     if wait:
       wait_check(command_buffer)
-      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
+      return elapsed_time(command_buffer)
     self.device.mtl_buffers_in_flight.append(command_buffer)
     return None
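The Metal backend no longer imports the PyObjC `Metal` module; every `obj.someSelector_()` call becomes `msg(obj, "someSelector:", ...)`, a thin ctypes wrapper over `objc_msgSend` now exported by `tinygrad/runtime/ops_metal.py` (changed in this diff but not shown). A rough sketch of what such a wrapper looks like, not the exact tinygrad implementation:

```python
import ctypes, ctypes.util

# Objective-C runtime: classes and selectors are looked up by name, and
# objc_msgSend dispatches the call.
libobjc = ctypes.CDLL(ctypes.util.find_library("objc"))
libobjc.objc_getClass.restype = ctypes.c_void_p
libobjc.sel_registerName.restype = ctypes.c_void_p

def msg(obj, selector: str, /, *args, restype=ctypes.c_void_p):
    # Equivalent of [obj selector:arg ...]; the caller picks the return type.
    sender = libobjc["objc_msgSend"]  # fresh function pointer per call
    sender.restype = restype
    return sender(obj, libobjc.sel_registerName(selector.encode()), *args)

# e.g. the descriptor creation in this hunk:
#   desc = msg(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"), "new")
```

Dropping PyObjC removes a heavyweight dependency; the cost is spelling out selector strings by hand and marshalling `MTLSize`-style struct arguments explicitly, which is what `to_struct` does in the hunk above.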