tinygrad-0.8.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
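The file list shows a package reorganization: the JIT, graph, realization, scheduling, and search machinery moves under a new tinygrad.engine package, features/multi.py is promoted to tinygrad/multi.py, and mlops.py is renamed to function.py. A rough sketch of what this likely means for downstream imports (the module paths come from the list above; TinyJit as the imported name is an assumption for illustration):

  # tinygrad 0.8.0: the JIT lived at the package root (path per the file list; name assumed)
  from tinygrad.jit import TinyJit
  # tinygrad 0.9.0: it moved under tinygrad.engine
  from tinygrad.engine.jit import TinyJit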
tinygrad/runtime/graph/hsa.py (new file)
@@ -0,0 +1,171 @@
+import ctypes, collections, time, itertools
+from typing import List, Any, Dict, cast, Optional, Tuple
+from tinygrad.helpers import GraphException, init_c_var, round_up
+from tinygrad.device import Buffer, BufferOptions
+from tinygrad.device import Compiled, Device
+from tinygrad.shape.symbolic import Variable
+from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
+from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+from tinygrad.engine.jit import MultiGraphRunner
+import tinygrad.runtime.autogen.hsa as hsa
+from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
+
+def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])]
+
+class VirtAQLQueue(AQLQueue):
+  def __init__(self, device, sz):
+    self.device = device
+    self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
+    self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
+    self.packets_count = 0
+    self.available_packet_slots = sz
+  def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!"
+  def _submit_packet(self):
+    self.write_addr += AQL_PACKET_SIZE
+    self.packets_count += 1
+    self.available_packet_slots -= 1
+
+class HSAGraph(MultiGraphRunner):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    super().__init__(jit_cache, input_rawbuffers, var_vals)
+
+    # Check all jit items are compatible.
+    compiled_devices = set()
+    for ji in self.jit_cache:
+      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
+      elif isinstance(ji.prg, BufferXfer):
+        for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
+      else: raise GraphException
+    if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
+
+    self.devices: List[HSADevice] = list(compiled_devices) #type:ignore
+
+    # Allocate kernel args.
+    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
+    for ji in self.jit_cache:
+      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+    kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}
+
+    # Fill initial arguments.
+    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
+    for j,ji in enumerate(self.jit_cache):
+      if not isinstance(ji.prg, CompiledRunner): continue
+      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
+      kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+      for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
+      for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
+
+    # Build queues.
+    self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
+    self.packets = {}
+    self.transfers = []
+    self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
+    self.signals_to_reset: List[hsa.hsa_signal_t] = []
+    self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
+    self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
+
+    # Special packet to wait for the world.
+    self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
+    for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])
+
+    for j,ji in enumerate(self.jit_cache):
+      if isinstance(ji.prg, CompiledRunner):
+        wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False)
+        for i in range(0, len(wait_signals), 5):
+          self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
+        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
+
+        sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
+        self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
+                                                          ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
+        if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
+      elif isinstance(ji.prg, BufferXfer):
+        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
+        dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
+        sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
+
+        wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
+        self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
+                               (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
+        self.ji_to_transfer[j] = len(self.transfers) - 1
+        if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))
+
+    # Wait for all active signals to finish the graph
+    wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
+    for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
+      for dev in self.signals_to_devices[v.handle]:
+        wait_signals_to_finish[dev].append(v)
+
+    self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+    for dev in self.devices:
+      wait_signals = wait_signals_to_finish[dev]
+      for i in range(0, max(1, len(wait_signals)), 5):
+        self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)
+
+    # Zero signals to allow graph to start and execute.
+    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
+    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
+
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+    # Wait and restore signals
+    hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
+    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))
+
+    # Update rawbuffers
+    for (j,i),input_idx in self.input_replace.items():
+      if j in self.ji_kargs_structs:
+        self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
+      else:
+        if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
+        elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
+
+    # Update var_vals
+    for j in self.jc_idx_with_updatable_var_vals:
+      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
+        self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
+
+    # Update launch dims
+    for j in self.jc_idx_with_updatable_launch_dims:
+      gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
+      self.packets[j].workgroup_size_x = lc[0]
+      self.packets[j].workgroup_size_y = lc[1]
+      self.packets[j].workgroup_size_z = lc[2]
+      self.packets[j].grid_size_x = gl[0] * lc[0]
+      self.packets[j].grid_size_y = gl[1] * lc[1]
+      self.packets[j].grid_size_z = gl[2] * lc[2]
+
+    for dev in self.devices:
+      dev.flush_hdp()
+      dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)
+
+    for transfer_data in self.transfers:
+      check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))
+
+    et = None
+    if wait:
+      st = time.perf_counter()
+      hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+      et = time.perf_counter() - st
+
+    for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
+    return et
+
+  def alloc_signal(self, reset_on_start=False, wait_on=None):
+    sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+    if reset_on_start: self.signals_to_reset.append(sync_signal)
+    if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
+    return sync_signal
+
+  def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
+    if isinstance(dep, hsa.hsa_signal_t): return dep
+    elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
+      if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
+      return packet.completion_signal
+    return None
+
+  def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False):
+    rdeps = self._access_resources(read, write, new_dependency)
+    wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
+    if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write]
+    return dedup_signals(wait_signals)
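HSAGraph encodes every kernel dispatch and buffer transfer from a captured jit_cache into per-device virtual AQL queues once, at construction time; each __call__ then only patches kernel arguments, variable values, and launch dimensions before blitting the prebuilt packets onto the hardware queues. A minimal sketch of how a graph runner like this is driven, assuming a jit_cache of ExecItem already captured by the JIT (the surrounding setup is hypothetical; the constructor and call signatures are taken from the diff above):

  # hypothetical driver: build once, replay many times
  graph = HSAGraph(jit_cache, input_rawbuffers, var_vals)  # encodes all AQL packets up front
  for step in range(100):
    et = graph(input_rawbuffers, var_vals, wait=True)      # replay; returns elapsed seconds when wait=True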
tinygrad/runtime/graph/metal.py
@@ -1,22 +1,17 @@
 from typing import List, Any, Dict, cast, Optional
-import numpy as np
 import Metal
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import dedup, unwrap2
-from tinygrad.device import Buffer, CompiledASTRunner, update_stats
-from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, GraphException
+from tinygrad.helpers import dedup, unwrap2, GraphException
+from tinygrad.device import Buffer
+from tinygrad.engine.realize import ExecItem, CompiledRunner
+from tinygrad.engine.jit import GraphRunner
 from tinygrad.shape.symbolic import Variable
-from tinygrad.runtime.ops_metal import MetalDevice
+from tinygrad.runtime.ops_metal import wait_check
 
-class MetalGraph:
-  def __init__(self, device:MetalDevice, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
-    if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
-
-    self.jit_cache = jit_cache
-    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
-    self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
-    self.device: MetalDevice = device
+class MetalGraph(GraphRunner):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    super().__init__(jit_cache, input_rawbuffers, var_vals)
+    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
 
     # create metal batch exec
     icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
@@ -24,56 +19,57 @@ class MetalGraph:
     icb_descriptor.setInheritBuffers_(False)
     icb_descriptor.setInheritPipelineState_(False)
     icb_descriptor.setMaxKernelBufferBindCount_(31)
-    self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0)) # noqa: E501
+    self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
+                                                                                                   Metal.MTLResourceOptions(0))
     if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
 
-    if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize)
-    all_resources = [self.int_buf] if len(var_vals) else []
+    if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
+    all_resources = [self.int_buf] if len(self.vars) else []
+
     for j,ji in enumerate(self.jit_cache):
-      prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
+      prg: CompiledRunner = cast(CompiledRunner, ji.prg)
       descriptor = Metal.MTLComputePipelineDescriptor.new()
       descriptor.setComputeFunction_(prg.clprg.fxn)
      descriptor.setSupportIndirectCommandBuffers_(True)
-      pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) # noqa: E501
       icb_command = self.icb.indirectComputeCommandAtIndex_(j)
-      icb_command.setComputePipelineState_(pipeline_state)
-      for i,b in enumerate(ji.rawbufs):
+      icb_command.setComputePipelineState_(unwrap2(
+        self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)))
+      for i,b in enumerate(ji.bufs):
         if b is not None:
           icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
           all_resources.append(b._buf)
-      var_vals_keys = list(var_vals.keys())
-      for i,v in enumerate(prg.vars):
-        icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
+      for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, self.vars.index(v)*4, len(ji.bufs)+i)
       if j not in self.jc_idx_with_updatable_launch_dims:
-        global_size, local_size = prg.launch_dims(var_vals)
+        global_size, local_size = prg.p.launch_dims(var_vals)
         icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
       icb_command.setBarrier()
+
     self.all_resources = dedup(all_resources)
     self.command_buffer: Any = None
-    if len(var_vals): self.int_buf_view = np.frombuffer(self.int_buf.contents().as_buffer(self.int_buf.length()), np.int32)
+    if len(self.vars): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
+
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+    if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
+    all_resources = dedup(self.all_resources + [x._buf for x in input_rawbuffers])
 
-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
-    # NOTE: you at least can't update the ints if this is running
-    if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
-    all_resources = self.all_resources + [x._buf for x in input_rawbuffers]
     for (j,i),input_idx in self.input_replace.items():
      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
     for j in self.jc_idx_with_updatable_launch_dims:
-      global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
-      self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) # noqa: E501
-    if len(var_vals): self.int_buf_view[:] = list(var_vals.values())
+      global_size, local_size = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
+      self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
+                                                                                                       Metal.MTLSize(*local_size))
+    for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
+
     command_buffer = self.device.mtl_queue.commandBuffer()
     encoder = command_buffer.computeCommandEncoder()
     encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
-    encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
+    encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0, len(self.jit_cache)))
     encoder.endEncoding()
     command_buffer.commit()
     self.command_buffer = command_buffer
+
     if wait:
-      command_buffer.waitUntilCompleted()
-      et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
-    else:
-      self.device.mtl_buffers_in_flight.append(command_buffer)
-      et = None
-    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) # noqa: E501
-    return et
+      wait_check(command_buffer)
+      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
+    self.device.mtl_buffers_in_flight.append(command_buffer)
+    return None
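A notable detail in the MetalGraph change is that it drops the numpy dependency: the int32 view over the buffer that holds symbolic-variable values is now a plain memoryview cast to 'i' instead of np.frombuffer. A self-contained sketch of the equivalence, using a bytearray to stand in for the Metal buffer contents (the real code casts the buffer returned by int_buf.contents().as_buffer(...)):

  import numpy as np

  raw = bytearray(4 * 4)                    # stand-in for the Metal buffer's bytes (4 int32 slots)
  old_view = np.frombuffer(raw, np.int32)   # 0.8.0 style: numpy int32 view
  new_view = memoryview(raw).cast('i')      # 0.9.0 style: stdlib-only, same writable int32 window
  new_view[2] = 42                          # writes land in the same underlying bytes
  assert old_view[2] == 42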