tinygrad-0.8.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries. A sketch of the import migration implied by the file reorganization is included after the file list below.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
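The most visible structural change above is the new tinygrad/engine package replacing the old top-level jit.py, realize.py, and graph.py modules and the features/ directory, plus the rename of mlops.py to function.py. As a rough guide to what that means for downstream imports, here is a minimal migration sketch (old and new paths inferred from the file list above; TinyJit, run_schedule, and time_linearizer are shown as typical entry points of those modules):

# Migration sketch for the 0.8.0 -> 0.9.0 module reorganization (paths per the file list above).
# 0.8.0:
#   from tinygrad.jit import TinyJit
#   from tinygrad.realize import run_schedule
#   from tinygrad.features.search import time_linearizer
# 0.9.0:
from tinygrad.engine.jit import TinyJit             # tinygrad/jit.py -> tinygrad/engine/jit.py
from tinygrad.engine.realize import run_schedule    # tinygrad/realize.py -> tinygrad/engine/realize.py
from tinygrad.engine.search import time_linearizer  # tinygrad/features/search.py -> tinygrad/engine/search.py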
tinygrad/runtime/graph/hsa.py
ADDED
@@ -0,0 +1,171 @@
+import ctypes, collections, time, itertools
+from typing import List, Any, Dict, cast, Optional, Tuple
+from tinygrad.helpers import GraphException, init_c_var, round_up
+from tinygrad.device import Buffer, BufferOptions
+from tinygrad.device import Compiled, Device
+from tinygrad.shape.symbolic import Variable
+from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
+from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+from tinygrad.engine.jit import MultiGraphRunner
+import tinygrad.runtime.autogen.hsa as hsa
+from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
+
+def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])]
+
+class VirtAQLQueue(AQLQueue):
+  def __init__(self, device, sz):
+    self.device = device
+    self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
+    self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
+    self.packets_count = 0
+    self.available_packet_slots = sz
+  def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!"
+  def _submit_packet(self):
+    self.write_addr += AQL_PACKET_SIZE
+    self.packets_count += 1
+    self.available_packet_slots -= 1
+
+class HSAGraph(MultiGraphRunner):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    super().__init__(jit_cache, input_rawbuffers, var_vals)
+
+    # Check all jit items are compatible.
+    compiled_devices = set()
+    for ji in self.jit_cache:
+      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
+      elif isinstance(ji.prg, BufferXfer):
+        for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
+      else: raise GraphException
+    if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
+
+    self.devices: List[HSADevice] = list(compiled_devices) #type:ignore
+
+    # Allocate kernel args.
+    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
+    for ji in self.jit_cache:
+      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+    kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}
+
+    # Fill initial arguments.
+    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
+    for j,ji in enumerate(self.jit_cache):
+      if not isinstance(ji.prg, CompiledRunner): continue
+      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
+      kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+      for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
+      for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
+
+    # Build queues.
+    self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
+    self.packets = {}
+    self.transfers = []
+    self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
+    self.signals_to_reset: List[hsa.hsa_signal_t] = []
+    self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
+    self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
+
+    # Special packet to wait for the world.
+    self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
+    for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])
+
+    for j,ji in enumerate(self.jit_cache):
+      if isinstance(ji.prg, CompiledRunner):
+        wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False)
+        for i in range(0, len(wait_signals), 5):
+          self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
+        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
+
+        sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
+        self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
+                                                          ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
+        if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
+      elif isinstance(ji.prg, BufferXfer):
+        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
+        dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
+        sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
+
+        wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
+        self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
+                               (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
+        self.ji_to_transfer[j] = len(self.transfers) - 1
+        if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))
+
+    # Wait for all active signals to finish the graph
+    wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
+    for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
+      for dev in self.signals_to_devices[v.handle]:
+        wait_signals_to_finish[dev].append(v)
+
+    self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+    for dev in self.devices:
+      wait_signals = wait_signals_to_finish[dev]
+      for i in range(0, max(1, len(wait_signals)), 5):
+        self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)
+
+    # Zero signals to allow graph to start and execute.
+    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
+    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
+
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+    # Wait and restore signals
+    hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
+    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))
+
+    # Update rawbuffers
+    for (j,i),input_idx in self.input_replace.items():
+      if j in self.ji_kargs_structs:
+        self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
+      else:
+        if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
+        elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
+
+    # Update var_vals
+    for j in self.jc_idx_with_updatable_var_vals:
+      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
+        self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
+
+    # Update launch dims
+    for j in self.jc_idx_with_updatable_launch_dims:
+      gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
+      self.packets[j].workgroup_size_x = lc[0]
+      self.packets[j].workgroup_size_y = lc[1]
+      self.packets[j].workgroup_size_z = lc[2]
+      self.packets[j].grid_size_x = gl[0] * lc[0]
+      self.packets[j].grid_size_y = gl[1] * lc[1]
+      self.packets[j].grid_size_z = gl[2] * lc[2]
+
+    for dev in self.devices:
+      dev.flush_hdp()
+      dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)
+
+    for transfer_data in self.transfers:
+      check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))
+
+    et = None
+    if wait:
+      st = time.perf_counter()
+      hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+      et = time.perf_counter() - st
+
+    for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
+    return et
+
+  def alloc_signal(self, reset_on_start=False, wait_on=None):
+    sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+    if reset_on_start: self.signals_to_reset.append(sync_signal)
+    if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
+    return sync_signal
+
+  def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
+    if isinstance(dep, hsa.hsa_signal_t): return dep
+    elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
+      if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
+      return packet.completion_signal
+    return None
+
+  def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False):
+    rdeps = self._access_resources(read, write, new_dependency)
+    wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
+    if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write]
+    return dedup_signals(wait_signals)
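A note on the barrier batching in HSAGraph above: an AQL barrier-AND packet (hsa_barrier_and_packet_t) has only five dep_signal slots, which is why wait signals are submitted in chunks of five (for i in range(0, len(wait_signals), 5)). A standalone sketch of that chunking logic, with plan_barriers as a hypothetical helper name for illustration:

# Standalone sketch of HSAGraph's barrier batching; plan_barriers is illustrative, not tinygrad API.
from typing import List

AQL_BARRIER_MAX_DEPS = 5  # dep_signal slots in an AQL barrier-AND packet

def plan_barriers(wait_signals: List[int]) -> List[List[int]]:
  # one barrier packet per chunk of at most five signals, in submission order
  return [wait_signals[i:i+AQL_BARRIER_MAX_DEPS] for i in range(0, len(wait_signals), AQL_BARRIER_MAX_DEPS)]

assert plan_barriers(list(range(12))) == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]]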
tinygrad/runtime/graph/metal.py
CHANGED
@@ -1,22 +1,17 @@
 from typing import List, Any, Dict, cast, Optional
-import numpy as np
 import Metal
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import dedup, unwrap2
-from tinygrad.device import Buffer
-from tinygrad.
+from tinygrad.helpers import dedup, unwrap2, GraphException
+from tinygrad.device import Buffer
+from tinygrad.engine.realize import ExecItem, CompiledRunner
+from tinygrad.engine.jit import GraphRunner
 from tinygrad.shape.symbolic import Variable
-from tinygrad.runtime.ops_metal import
+from tinygrad.runtime.ops_metal import wait_check
 
-class MetalGraph:
-  def __init__(self,
-
-
-    self.jit_cache = jit_cache
-    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
-    self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
-    self.device: MetalDevice = device
+class MetalGraph(GraphRunner):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    super().__init__(jit_cache, input_rawbuffers, var_vals)
+    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
 
     # create metal batch exec
     icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
@@ -24,56 +19,57 @@ class MetalGraph:
     icb_descriptor.setInheritBuffers_(False)
     icb_descriptor.setInheritPipelineState_(False)
     icb_descriptor.setMaxKernelBufferBindCount_(31)
-    self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
+    self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
+                                                                                                  Metal.MTLResourceOptions(0))
     if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
 
-    if len(
-    all_resources = [self.int_buf] if len(
+    if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
+    all_resources = [self.int_buf] if len(self.vars) else []
+
     for j,ji in enumerate(self.jit_cache):
-      prg:
+      prg: CompiledRunner = cast(CompiledRunner, ji.prg)
       descriptor = Metal.MTLComputePipelineDescriptor.new()
       descriptor.setComputeFunction_(prg.clprg.fxn)
      descriptor.setSupportIndirectCommandBuffers_(True)
-      pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) # noqa: E501
       icb_command = self.icb.indirectComputeCommandAtIndex_(j)
-      icb_command.setComputePipelineState_(
-
+      icb_command.setComputePipelineState_(unwrap2(
+        self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)))
+      for i,b in enumerate(ji.bufs):
         if b is not None:
          icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
          all_resources.append(b._buf)
-
-      for i,v in enumerate(prg.vars):
-        icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
+      for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, self.vars.index(v)*4, len(ji.bufs)+i)
      if j not in self.jc_idx_with_updatable_launch_dims:
-        global_size, local_size = prg.launch_dims(var_vals)
+        global_size, local_size = prg.p.launch_dims(var_vals)
        icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
      icb_command.setBarrier()
+
    self.all_resources = dedup(all_resources)
    self.command_buffer: Any = None
-    if len(
+    if len(self.vars): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
+
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+    if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
+    all_resources = dedup(self.all_resources + [x._buf for x in input_rawbuffers])
 
-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
-    # NOTE: you at least can't update the ints if this is running
-    if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
-    all_resources = self.all_resources + [x._buf for x in input_rawbuffers]
     for (j,i),input_idx in self.input_replace.items():
      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
    for j in self.jc_idx_with_updatable_launch_dims:
-      global_size, local_size = cast(
-      self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
-
+      global_size, local_size = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
+      self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
+                                                                                                       Metal.MTLSize(*local_size))
+    for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
+
    command_buffer = self.device.mtl_queue.commandBuffer()
    encoder = command_buffer.computeCommandEncoder()
    encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
-    encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
+    encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0, len(self.jit_cache)))
    encoder.endEncoding()
    command_buffer.commit()
    self.command_buffer = command_buffer
+
    if wait:
-      command_buffer
-
-
-
-      et = None
-    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) # noqa: E501
-    return et
+      wait_check(command_buffer)
+      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
+    self.device.mtl_buffers_in_flight.append(command_buffer)
+    return None