tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
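
Note: a change that recurs in almost every file below is the migration from typing aliases to built-in generics (PEP 585, Python 3.9+) and X|None unions (PEP 604, Python 3.10+). An illustrative before/after sketch (launch_old/launch_new are hypothetical names, not taken from the package):

# 0.10.0 style
from typing import List, Dict, Optional
def launch_old(dims: List[int], vals: Dict[str, int]) -> Optional[float]: ...

# 0.10.2 style: builtin generics (3.9+) and | unions (3.10+)
def launch_new(dims: list[int], vals: dict[str, int]) -> float|None: ...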
tinygrad/runtime/graph/cuda.py
@@ -1,5 +1,5 @@
 import ctypes
-from typing import Any, Optional, Tuple, Dict, List, cast
+from typing import Any, cast
 import tinygrad.runtime.autogen.cuda as cuda
 from tinygrad.helpers import init_c_var, dedup
 from tinygrad.device import Buffer, Device
@@ -9,18 +9,18 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner, GraphException
 
 class CUDAGraph(MultiGraphRunner):
-  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
     super().__init__(jit_cache, input_rawbuffers, var_vals)
 
     # Check all jit items are compatible.
     if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
 
     self.jc_idx_with_updatable_rawbufs = dedup([x[0] for x in self.input_replace.keys()])
-    self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
+    self.updatable_nodes: dict[int, tuple[Any, Any, Any, bool]] = {} # dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
 
     self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
 
-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       if isinstance(ji.prg, CompiledRunner):
         global_size, local_size = ji.prg.p.launch_dims(var_vals)
 
@@ -29,7 +29,7 @@ class CUDAGraph(MultiGraphRunner):
         c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
 
         c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
-        kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
+        kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
         check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
 
       if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
@@ -48,7 +48,7 @@ class CUDAGraph(MultiGraphRunner):
 
     self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
 
-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
     # Update rawbuffers in the c_args struct.
     for (j,i),input_idx in self.input_replace.items():
       if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
@@ -61,9 +61,8 @@ class CUDAGraph(MultiGraphRunner):
 
     # Update launch dims in the kern_params struct.
     for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
-      prg = cast(CompiledRunner, self.jit_cache[j].prg)
-      node, global_size, local_size = self.updatable_nodes[j][1], global_dims or prg.p.global_size, local_dims or prg.p.local_size
-      node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size # type: ignore[misc]
+      node = self.updatable_nodes[j][1]
+      node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_dims, *global_dims # type: ignore[misc]
 
     # Update graph nodes with the updated structs.
     for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
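
Note: both graph objects above are created with the init_c_var ctypes out-parameter pattern (cuGraphCreate, cuGraphInstantiate_v2). A minimal sketch of the helper's semantics, assuming it behaves as its call sites suggest (the real helper lives in tinygrad/helpers.py and may differ in detail):

import ctypes

def init_c_var(c_var, init_cb):
  # run the initializer callback (e.g. lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0))),
  # then hand back the now-initialized ctypes value
  init_cb(c_var)
  return c_var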
tinygrad/runtime/graph/hcq.py
@@ -1,58 +1,77 @@
 import collections, time
-from typing import List, Any, Dict, cast, Optional, Tuple, Set
-from tinygrad.helpers import round_up, PROFILE, memsize_to_str
-from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, HCQArgsState
-from tinygrad.device import Buffer, BufferOptions, Compiled, Device
-from tinygrad.ops import Variable
+from typing import Any, cast
+from tinygrad.helpers import round_up, PROFILE
+from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
+from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
+from tinygrad.dtype import dtypes
+from tinygrad.ops import UOp, Variable
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner
 
 class HCQGraph(MultiGraphRunner):
-  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
     super().__init__(jit_cache, input_rawbuffers, var_vals)
     self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
 
+    # Replace input buffers with variables.
+    self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
+    self.input_replace_to_var: dict[tuple[int, int], Variable] = {}
+
+    for (j,i), input_idx in self.input_replace.items():
+      x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
+      self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable
+
     # Allocate kernel args.
-    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
-    for ji in self.jit_cache:
+    kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
+    for ji in jit_cache:
       if not isinstance(ji.prg, CompiledRunner): continue
-      kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
-    self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
+      kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
+    self.kernargs_bufs: dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferSpec(cpu_access=True)) for dev,sz in kernargs_size.items()}
 
     # Fill initial arguments.
-    self.ji_args: Dict[int, HCQArgsState] = {}
+    self.ji_args: dict[int, HCQArgsState] = {}
 
-    kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
-    for j,ji in enumerate(self.jit_cache):
+    kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size, base=cast(int, buf.va_addr)) for dev,buf in self.kernargs_bufs.items()}
+    for j,ji in enumerate(jit_cache):
       if not isinstance(ji.prg, CompiledRunner): continue
-      kargs_ptrs[ji.prg.device] = (kargs_ptr:=kargs_ptrs[ji.prg.device]) + round_up(ji.prg.clprg.kernargs_alloc_size, 16)
-      self.ji_args[j] = ji.prg.clprg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)
+
+      self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
 
     # Schedule Dependencies.
     # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
     # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
     # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device's
     # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
-    self.ji_schedule: Dict[int, Tuple[HCQCompiled, HWCommandQueue, List, List, HCQSignal, Optional[int]]] = {}
+    self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, int|None]] = {}
 
-    self.comp_queues: Dict[HCQCompiled, HWComputeQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
-    self.copy_queues: Dict[HCQCompiled, HWCopyQueue] = {} # lazy allocation
+    self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
+    self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
 
-    self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
+    self.signals: dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
     self.kickoff_value: int = 0
+    self.kickoff_var = UOp.variable("kickoff_var", 0, 0xffffffff, dtype=dtypes.uint32)
 
-    self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
-    self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
+    # When profiling, allocate 2 signals for each jit item to measure speed. The jth jit item has signals at 2*j and 2*j+1.
+    # TODO: This logic might allocate a few extra signals...
+    self.prof_signals: list[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
+    self.prog_graph_deps: list[list[int]] = []
+    self.prof_graph_entries: list[ProfileGraphEntry] = []
 
-    last_j: Dict[HWCommandQueue, Optional[int]] = collections.defaultdict(lambda: None)
-    queue_access: Dict[HWCommandQueue, Dict[HWCommandQueue, Optional[int]]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
-    dev_access: Dict[HWCommandQueue, Set[HCQCompiled]] = collections.defaultdict(set)
+    last_j: dict[HWQueue, int|None] = collections.defaultdict(lambda: None)
+    queue_access: dict[HWQueue, dict[HWQueue, int|None]] = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
+    dev_access: dict[HWQueue, set[HCQCompiled]] = collections.defaultdict(set)
 
     for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
 
-    for j,ji in enumerate(self.jit_cache):
-      enqueue_dev = ji.prg.device if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
-      enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
+    for j,ji in enumerate(jit_cache):
+      enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
+
+      if is_exec_prg:
+        enqueue_queue = self.comp_queues[enqueue_dev]
+      else:
+        assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
+        enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
+
       out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
 
       # Get dependencies based on input and output buffers.
@@ -70,13 +89,13 @@ class HCQGraph(MultiGraphRunner):
 
       # Ensure device is ready for use in current context: the graph has initialized the device and it's safe to operate on it within this graph.
      for dep_queue, _ in opt_deps: dev_access[enqueue_queue].update(dev_access[dep_queue])
-      sync_signals = [(self.signals[d], self.kickoff_value) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
+      sync_signals = [(self.signals[d], self.kickoff_var) for b in ji.bufs if (d:=Device[cast(Buffer, b).device]) not in dev_access[enqueue_queue]]
       dev_access[enqueue_queue].update(cast(HCQCompiled, Device[cast(Buffer, b).device]) for b in ji.bufs)
 
       # Remove self-dependency for compute and copy queues.
       # For compute, in case of NV, optimize when only 1 same-queue dependency exists, since NV chains 2+ executions in this case,
       # eliminating dependency need.
-      dname = enqueue_dev.dname.split(":", 1)[0]
+      dname = enqueue_dev.device.split(":", 1)[0]
       can_opt = dname in {"AMD", "QCOM"} or (dname == "NV" and len(sync_signals) == 0 and len(opt_deps) == 1 and id(opt_deps[0][0]) == id(out_signal))
       if can_opt or isinstance(ji.prg, BufferXfer): opt_deps = [x for x in opt_deps if id(x[0]) != id(out_signal)]
@@ -86,48 +105,52 @@
 
       # Collect profile information if profiling is enabled.
       if PROFILE:
-        prof_ji_desc = ji.prg.clprg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
-
-        sig_st, sig_en = (j * 2, True), (j * 2 + 1, True)
-        if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None: sig_st = (prev_ji * 2 + 1, False)
+        # When executions are chained, we can reuse the end timestamp from the previous command as the start timestamp for the current command.
+        sig_st = prev_ji * 2 + 1 if len(opt_deps) == 0 and (prev_ji:=last_j[enqueue_queue]) is not None else j * 2
 
-        if is_exec_prg: prof_args = None
-        else: prof_args = {"Size": memsize_to_str(ji.bufs[0].nbytes), "GB/S": lambda dur, b=ji.bufs[0].nbytes: f"{b/1e3/dur:.2f}"} # type: ignore
+        # Description based on the command.
+        prof_ji_desc = ji.prg._prg.name if is_exec_prg else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
 
-        self.prof_records.append((sig_st, sig_en, enqueue_dev, prof_ji_desc, not is_exec_prg, [d - 1 for _, d in rdeps], prof_args))
+        self.prof_graph_entries.append(ProfileGraphEntry(enqueue_dev.device, prof_ji_desc, sig_st, j * 2 + 1, is_copy=not is_exec_prg))
+        self.prog_graph_deps.append([d - 1 for _, d in rdeps])
 
       last_j[enqueue_queue] = j
 
+    # Check which signals are used in the profile graph.
+    self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.prof_signals))]
+
     # Build hardware queues.
-    self.op_cmd_idx: Dict[int, Tuple[Any, int]] = {}
-    self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
-    self.kickoff_wait_cmds: Dict[HWCommandQueue, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
+    self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}
+
+    # Create variable timeline signals for each device.
+    timeline_sigaddrs = {dev: UOp.variable(f"timeline_sig_{dev.device_id}", 0, 0xffffffffffffffff, dtype=dtypes.uint64) for dev in self.devices}
+    self.virt_timeline_vals = {dev: UOp.variable(f"timeline_var_{dev.device_id}", 0, 0xffffffff, dtype=dtypes.uint32) for dev in self.devices}
+    self.virt_timeline_signals = {dev: dev.signal_t(base_addr=timeline_sigaddrs[dev], timeline_for_device=dev) for dev in self.devices}
 
     for dev in self.devices:
-      self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
-                           .wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev], self.kickoff_value)
+      self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
+                           .wait(self.signals['CPU'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
 
-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
 
-      for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
       for sig, val in sync_signals + deps: enqueue_queue.wait(sig, val)
 
       # Encode waits and start profile timestamp (if needed).
-      if PROFILE and self.prof_records[j][0][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][0][0]])
+      if PROFILE and self.prof_signal_is_used[j * 2]: enqueue_queue.timestamp(self.prof_signals[j * 2])
 
      # Encode main commands based on ji type.
       if isinstance(ji.prg, CompiledRunner):
-        cast(HWComputeQueue, enqueue_queue).exec(ji.prg.clprg, self.ji_args[j], *ji.prg.p.launch_dims(var_vals))
+        enqueue_queue.exec(ji.prg._prg, self.ji_args[j], tuple(ji.prg.p.global_size or (1,1,1)), tuple(ji.prg.p.local_size or (1,1,1)))
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
         cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
-        cast(HWCopyQueue, enqueue_queue).copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
+
+        enqueue_queue.copy(self.hcq_bufs[j][0].va_addr, self.hcq_bufs[j][1].va_addr, dest.nbytes)
         self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
-      self.op_cmd_idx[j] = (enqueue_queue, len(enqueue_queue) - 1)
 
       # Encode finish profile timestamp (if needed).
-      if PROFILE and self.prof_records[j][1][1]: enqueue_queue.timestamp(self.prof_signals[self.prof_records[j][1][0]])
+      if PROFILE and self.prof_signal_is_used[j * 2 + 1]: enqueue_queue.timestamp(self.prof_signals[j * 2 + 1])
 
       if signal_val is not None: enqueue_queue.signal(signal, signal_val)
@@ -135,13 +158,13 @@
       for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
         if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
 
-      self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value).bind(dev)
+      self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
       if dev in self.copy_queues: self.copy_queues[dev].bind(dev)
 
-    self.last_timeline: Dict[HCQCompiled, Tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
+    self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
     self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
 
-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
     # Wait and restore signals
     self.kickoff_value += 1
     for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
@@ -150,28 +173,16 @@
 
     if PROFILE and self.kickoff_value > 1: self.collect_timestamps()
 
-    # Update rawbuffers
-    for (j,i),input_idx in self.input_replace.items():
-      if j in self.ji_args: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf)
-      else: self.op_cmd_idx[j][0].update_copy(self.op_cmd_idx[j][1], **{('dest' if i == 0 else 'src'): input_rawbuffers[input_idx]._buf.va_addr})
-
-    # Update var_vals
-    for j, i, v in self.updated_vars(var_vals): self.ji_args[j].update_var(i, v)
+    hcq_var_vals = {self.kickoff_var: self.kickoff_value, **var_vals,
+                    **{var: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
+                    **{sig.base_addr: dev.timeline_signal.base_addr for dev, sig in self.virt_timeline_signals.items()}}
 
-    # Update launch dims
-    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
-      queue, cmd_ptr = self.op_cmd_idx[j]
-      queue.update_exec(cmd_ptr, global_dims, local_dims)
+    # Update rawbuffers
+    for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr
 
     for dev in self.devices:
-      comp_queue, copy_queue, need_sig_upd = self.comp_queues[dev], self.copy_queues.get(dev, None), dev.timeline_signal != self.last_timeline[dev][0]
-      comp_queue.update_wait(1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value - 1) \
-                .update_wait(2, value=self.kickoff_value).update_signal(3, value=self.kickoff_value) \
-                .update_signal(len(comp_queue)-1, dev.timeline_signal if need_sig_upd else None, dev.timeline_value).submit(dev)
-
-      if copy_queue is not None:
-        for cmd_idx in self.kickoff_wait_cmds[copy_queue]: copy_queue.update_wait(cmd_idx, value=self.kickoff_value)
-        copy_queue.submit(dev)
+      self.comp_queues[dev].submit(dev, hcq_var_vals)
+      if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals)
 
       self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
       dev.timeline_value += 1
@@ -183,18 +194,12 @@
     return None
 
   def collect_timestamps(self):
-    timestamps = [s.timestamp for s in self.prof_signals]
-
-    for (st,_), (en,_), dev, desc, is_cp, deps, args in self.prof_records:
-      dev.raw_prof_records += [(timestamps[st], timestamps[en], desc, is_cp, args)]
-
-      for x in deps:
-        (b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
-        dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
+    # NOTE: Appending to any device is fine...
+    self.devices[0].profile_events += [ProfileGraphEvent(self.prof_graph_entries, self.prog_graph_deps, [s.timestamp for s in self.prof_signals])]
 
   def __del__(self):
     for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
 
     if PROFILE and self.kickoff_value >= 1: self.collect_timestamps()
 
-    for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferOptions(cpu_access=True))
+    for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
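
Note: the HCQGraph rewrite above replaces post-hoc queue patching (update_wait/update_signal/update_exec and kickoff_wait_cmds) with UOp.variable placeholders that are bound at submit time through hcq_var_vals, and it carves kernel-argument space out of each device's buffer with a BumpAllocator. A minimal sketch of a bump allocator matching the interface used above (BumpAllocator(size, base=...).alloc(size, alignment)); the real implementation lives in tinygrad/runtime/support/allocator.py and may differ:

class BumpAllocator:
  # linear allocator: the pointer only moves forward, everything is freed at once
  def __init__(self, size:int, base:int=0):
    self.size, self.base, self.ptr = size, base, 0
  def alloc(self, size:int, alignment:int=1) -> int:
    self.ptr = (self.ptr + alignment - 1) // alignment * alignment  # round up to alignment
    assert self.ptr + size <= self.size, "bump allocator out of space"
    self.ptr += size
    return self.base + self.ptr - size

# usage mirroring the diff: carve 16-byte-aligned kernarg slots out of one buffer
kargs = BumpAllocator(4096, base=0x1000)
slot0, slot1 = kargs.alloc(0x120, 16), kargs.alloc(0x80, 16)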
tinygrad/runtime/graph/metal.py
@@ -1,4 +1,4 @@
-from typing import List, Any, Dict, cast, Optional
+from typing import Any, cast
 import ctypes
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import dedup, getenv
@@ -7,7 +7,7 @@ from tinygrad.engine.realize import ExecItem, CompiledRunner
 from tinygrad.engine.jit import GraphRunner, GraphException
 from tinygrad.ops import Variable
 from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
-  MTLResourceOptions, elapsed_time, objc_id
+  MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
 
 class MTLIndirectCommandType:
   MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
@@ -17,68 +17,64 @@ class MTLResourceUsage:
   MTLResourceUsageWrite = 0b10
 
 class MetalGraph(GraphRunner):
-  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
     super().__init__(jit_cache, input_rawbuffers, var_vals)
     if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
 
     # create metal batch exec
-    icb_descriptor = msg(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"), "new", restype=objc_instance)
-    msg(icb_descriptor, "setCommandTypes:", MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
-    msg(icb_descriptor, "setInheritBuffers:", False)
-    msg(icb_descriptor, "setInheritPipelineState:", False)
-    msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)
-
-    self.icb = msg(self.device.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
-                   icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
+    icb_descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"))
+    msg("setCommandTypes:")(icb_descriptor, MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
+    msg("setInheritBuffers:")(icb_descriptor, False)
+    msg("setInheritPipelineState:")(icb_descriptor, False)
+    msg("setMaxKernelBufferBindCount:")(icb_descriptor, 31)
+
+    self.icb = msg("newIndirectCommandBufferWithDescriptor:maxCommandCount:options:", objc_instance)(self.dev.sysdevice,
+      icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache)
     if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
-    icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
+    icb_label = bytes(msg("UTF8String", ctypes.c_char_p)(msg("description", objc_instance)(self.icb))).decode()
     self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3
 
-    if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
+    if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
     all_resources = [self.int_buf.buf] if len(self.vars) else []
     all_pipelines = []
-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       prg: CompiledRunner = cast(CompiledRunner, ji.prg)
-      icb_command = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_instance)
-      all_pipelines.append(prg.clprg.pipeline_state)
-      msg(icb_command, "setComputePipelineState:", prg.clprg.pipeline_state)
+      icb_command = msg("indirectComputeCommandAtIndex:", objc_instance)(self.icb, j)
+      all_pipelines.append(prg._prg.pipeline_state)
+      msg("setComputePipelineState:")(icb_command, prg._prg.pipeline_state)
      for i,b in enumerate(ji.bufs):
        if b is not None and b not in input_rawbuffers:
-          msg(icb_command, "setKernelBuffer:offset:atIndex:", b._buf.buf, b._buf.offset, i)
+          msg("setKernelBuffer:offset:atIndex:")(icb_command, b._buf.buf, b._buf.offset, i)
          all_resources.append(b._buf.buf)
-      for i,v in enumerate(prg.p.vars): msg(icb_command, "setKernelBuffer:offset:atIndex:", self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
+      for i,v in enumerate(prg.p.vars): msg("setKernelBuffer:offset:atIndex:")(icb_command, self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
 
       global_size, local_size = prg.p.launch_dims(var_vals)
-      msg(icb_command, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
-      msg(icb_command, "setBarrier")
+      msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(icb_command, to_struct(*global_size), to_struct(*local_size))
+      msg("setBarrier")(icb_command)
 
     self.all_resources = dedup(all_resources)
     self.all_pipelines = dedup(all_pipelines)
     self.command_buffer: Any = None
-    if len(self.vars): self.int_buf_view = self.device.allocator.as_buffer(self.int_buf).cast('i')
-    self.range = to_struct(0, len(self.jit_cache))
+    if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
+    self.range = to_struct(0, len(jit_cache))
 
-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+  def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], wait=False) -> float|None:
 
-    if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
+    if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
     all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])
 
     for (j,i),input_idx in self.input_replace.items():
-      computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_id)
-      msg(computeCommand, "setKernelBuffer:offset:atIndex:", input_rawbuffers[input_idx]._buf.buf,
-          input_rawbuffers[input_idx]._buf.offset, i)
+      computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
+      msg("setKernelBuffer:offset:atIndex:")(computeCommand, input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
 
     for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
-      prg = cast(CompiledRunner, self.jit_cache[j].prg)
-      global_size, local_size = global_dims or prg.p.global_size, local_dims or prg.p.local_size
-      computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j)
-      msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:",
-          to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
+      computeCommand = msg("indirectComputeCommandAtIndex:", objc_id)(self.icb, j)
+      msg("concurrentDispatchThreadgroups:threadsPerThreadgroup:")(computeCommand, to_struct(*global_dims), to_struct(*local_dims))
     for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
 
-    command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
-    encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
-    msg(encoder, "useResources:count:usage:", (objc_id * len(all_resources))(*all_resources), len(all_resources),
+    command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
+    encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
+    msg("useResources:count:usage:")(encoder, (objc_id * len(all_resources))(*all_resources), len(all_resources),
       MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
@@ -88,16 +84,17 @@ class MetalGraph(GraphRunner):
     # to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
     if getenv("FIX_METAL_ICB", self.needs_icb_fix):
      for ps in self.all_pipelines:
-        msg(encoder, "setComputePipelineState:", ps)
-        msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(0,0,0), to_struct(0,0,0))
+        msg("setComputePipelineState:")(encoder, ps)
+        msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(0,0,0), to_struct(0,0,0))
 
-    msg(encoder, "executeCommandsInBuffer:withRange:", self.icb, self.range)
-    msg(encoder, "endEncoding")
-    msg(command_buffer, "commit")
+    msg("executeCommandsInBuffer:withRange:")(encoder, self.icb, self.range)
+    msg("endEncoding")(encoder)
+    msg("setLabel:")(command_buffer, to_ns_str(f"batched {len(self.jit_cache)}"))
+    msg("commit")(command_buffer)
     self.command_buffer = command_buffer
 
+    self.dev.mtl_buffers_in_flight.append(command_buffer)
    if wait:
       wait_check(command_buffer)
-      return elapsed_time(command_buffer)
-    self.device.mtl_buffers_in_flight.append(command_buffer)
+      return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
     return None
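
Note: the Metal changes switch msg from msg(receiver, selector, *args, restype=...) to a curried msg(selector, restype)(receiver, *args), and compute elapsed time as an end-minus-start timestamp pair instead of a single elapsed_time call. A hedged sketch of what such helpers could look like (macOS only; GPUStartTime/GPUEndTime are real MTLCommandBuffer properties returning seconds as doubles, but the actual helpers in tinygrad/runtime/ops_metal.py may differ):

import ctypes, functools
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
libobjc.sel_registerName.restype, libobjc.sel_registerName.argtypes = ctypes.c_void_p, [ctypes.c_char_p]

@functools.lru_cache(None)
def msg(selector: str, restype=ctypes.c_void_p):
  # resolve the selector once, return a callable bound to it
  sel = libobjc.sel_registerName(selector.encode())
  def _send(receiver, *args):
    libobjc.objc_msgSend.restype = restype  # illustration only: real code also sets argtypes
    return libobjc.objc_msgSend(receiver, sel, *args)
  return _send

def cmdbuf_st_time(cbuf): return msg("GPUStartTime", ctypes.c_double)(cbuf)
def cmdbuf_en_time(cbuf): return msg("GPUEndTime", ctypes.c_double)(cbuf)
# elapsed seconds for a completed command buffer: cmdbuf_en_time(cbuf) - cmdbuf_st_time(cbuf)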