tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
File without changes
|
@@ -0,0 +1,56 @@
|
|
1
|
+
import ctypes
|
2
|
+
import tinygrad.runtime.autogen.comgr as comgr
|
3
|
+
|
4
|
+
def check(status):
  """Raise RuntimeError when a comgr API call returned a non-zero (failure) status.

  The message includes comgr's own description of the status when comgr can provide one.
  Returns None for a zero status.
  """
  if status != 0:
    status_str = ctypes.POINTER(ctypes.c_char)()
    comgr.amd_comgr_status_string(status, ctypes.byref(status_str))
    # amd_comgr_status_string may leave the pointer NULL (e.g. for a status it does not know);
    # ctypes.string_at on a NULL pointer would crash the interpreter instead of raising.
    desc = ctypes.string_at(status_str).decode() if status_str else "unknown status"
    raise RuntimeError(f"comgr fail {status}, {desc}")
|
8
|
+
|
9
|
+
def _get_comgr_data(data_set, data_type):
  """Return the raw bytes of the data object at index 0 of kind `data_type` in `data_set`.

  The data handle obtained from the set is always released, even when reading its contents fails.
  Raises RuntimeError (via check) on any comgr failure.
  """
  check(comgr.amd_comgr_action_data_get_data(data_set, data_type, 0, ctypes.byref(data_exec := comgr.amd_comgr_data_t())))
  try:
    # first call (NULL buffer) only queries the size; second call fills a buffer of exactly that size
    check(comgr.amd_comgr_get_data(data_exec, ctypes.byref(sz := ctypes.c_uint64()), None))
    check(comgr.amd_comgr_get_data(data_exec, ctypes.byref(sz), (dat := ctypes.create_string_buffer(sz.value))))
  finally:
    check(comgr.amd_comgr_release_data(data_exec))
  return bytes(dat)
|
15
|
+
|
16
|
+
# AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
|
17
|
+
def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
  """Build an executable code object for `arch` from HIP source, or from assembly when `asm` is True, via comgr.

  HIP path:  source -> bitcode (with device libs) -> relocatable -> executable.
  asm path:  source -> relocatable -> executable.
  If the compile/assemble step fails, the comgr log is printed and RuntimeError is raised.
  Every comgr handle created here is released on both the success and the failure path.
  """
  check(comgr.amd_comgr_create_action_info(ctypes.byref(action_info := comgr.amd_comgr_action_info_t())))
  data_sets: list = []  # successfully created data sets, destroyed in the finally block
  data_src = None
  try:
    check(comgr.amd_comgr_action_info_set_language(action_info, comgr.AMD_COMGR_LANGUAGE_HIP))
    check(comgr.amd_comgr_action_info_set_isa_name(action_info, b"amdgcn-amd-amdhsa--" + arch.encode()))
    check(comgr.amd_comgr_action_info_set_logging(action_info, True))

    # data sets in pipeline order: src, bc, reloc, exec
    for _ in range(4):
      check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set := comgr.amd_comgr_data_set_t())))
      data_sets.append(data_set)
    data_set_src, data_set_bc, data_set_reloc, data_set_exec = data_sets

    check(comgr.amd_comgr_create_data(comgr.AMD_COMGR_DATA_KIND_SOURCE, ctypes.byref(src := comgr.amd_comgr_data_t())))
    data_src = src  # only tracked once creation succeeded, so cleanup never releases an invalid handle
    check(comgr.amd_comgr_set_data(data_src, len(rprg := prg.encode()), rprg))

    if asm:
      check(comgr.amd_comgr_set_data_name(data_src, b"<null>.s"))
      check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
      status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_ASSEMBLE_SOURCE_TO_RELOCATABLE, action_info, data_set_src, data_set_reloc)
      if status != 0:
        print(_get_comgr_data(data_set_reloc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
        raise RuntimeError("assemble failed")
    else:
      check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
      check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
      # -include hiprtc_runtime.h was removed
      check(comgr.amd_comgr_action_info_set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode())) # noqa: E501
      status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
      if status != 0:
        print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
        raise RuntimeError("compile failed")
      check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
      check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))

    # link step is shared by both paths
    check(comgr.amd_comgr_action_info_set_options(action_info, b""))
    check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action_info, data_set_reloc, data_set_exec))
    return _get_comgr_data(data_set_exec, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)
  finally:
    # same release order as before: source data, data sets (creation order), action info
    if data_src is not None: check(comgr.amd_comgr_release_data(data_src))
    for data_set in data_sets: check(comgr.amd_comgr_destroy_data_set(data_set))
    check(comgr.amd_comgr_destroy_action_info(action_info))
|
File without changes
|
@@ -0,0 +1,39 @@
|
|
1
|
+
from typing import List, Dict, cast
|
2
|
+
import ctypes
|
3
|
+
from tinygrad.helpers import dedup, cpu_time_execution, GraphException, DEBUG
|
4
|
+
from tinygrad.engine.jit import GraphRunner
|
5
|
+
from tinygrad.device import Buffer, Device
|
6
|
+
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
7
|
+
from tinygrad.shape.symbolic import Variable
|
8
|
+
from tinygrad.runtime.ops_clang import ClangProgram
|
9
|
+
from tinygrad.renderer.cstyle import ClangRenderer
|
10
|
+
# Bound render_dtype of a ClangRenderer instance; used by ClangGraph to spell C pointer types in the generated `batched` function.
render_dtype = ClangRenderer().render_dtype
|
11
|
+
|
12
|
+
class ClangGraph(GraphRunner):
  """Fuse every kernel of a CLANG jit cache into one C function named `batched`, compiled once.

  Input buffers become pointer parameters `arg0, arg1, ...`, followed by one `int` per symbolic
  variable (sorted by name). Every other buffer is baked into the generated source as an absolute address.
  """
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    # only compiled kernels can be stitched together into a single C translation unit
    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

    runners: List[CompiledRunner] = [cast(CompiledRunner, ji.prg) for ji in jit_cache]
    kernel_src = '\n'.join(dedup([runner.p.src for runner in runners]))

    params = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
    params += sorted([f"int {v.expr}" for v in var_vals])
    body = ["void batched("+','.join(params)+") {"]
    for ji, runner in zip(jit_cache, runners):
      call_args: List[str] = []
      for buf in ji.bufs:
        assert buf is not None
        if buf in input_rawbuffers: call_args.append(f"arg{input_rawbuffers.index(buf)}")
        else: call_args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}")
      call_args += [x.expr for x in runner.p.vars]
      body.append(f" {runner.p.function_name}({','.join(call_args)});")
    body.append("}")
    if DEBUG >= 4: print("\n".join(body))

    compiler = Device["CLANG"].compiler
    assert compiler is not None
    # no point in caching the pointers: non-input buffers are already hard-coded in the source
    self.clprg = ClangProgram("batched", compiler.compile(kernel_src+"\n"+"\n".join(body)))

  def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
    def run():
      buf_handles = [buf._buf for buf in rawbufs]
      # same ordering as the `int` parameters of `batched`: by variable name
      ordered_vals = [val for _, val in sorted(var_vals.items(), key=lambda kv: kv[0].expr)]
      return self.clprg(*buf_handles, *ordered_vals)
    return cpu_time_execution(run, enable=wait)
|
tinygrad/runtime/graph/cuda.py
CHANGED
@@ -1,76 +1,81 @@
|
|
1
1
|
import ctypes
|
2
2
|
from typing import Any, Optional, Tuple, Dict, List, cast
|
3
|
-
import
|
4
|
-
from tinygrad.helpers import init_c_var,
|
5
|
-
from tinygrad.device import
|
6
|
-
from tinygrad.runtime.ops_cuda import check, cu_time_execution
|
3
|
+
import tinygrad.runtime.autogen.cuda as cuda
|
4
|
+
from tinygrad.helpers import init_c_var, GraphException, dedup
|
5
|
+
from tinygrad.device import Buffer, Device
|
6
|
+
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
|
7
7
|
from tinygrad.shape.symbolic import Variable
|
8
|
-
from tinygrad.
|
8
|
+
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
9
|
+
from tinygrad.engine.jit import MultiGraphRunner
|
9
10
|
|
10
|
-
class CUDAGraph:
|
11
|
-
def __init__(self, jit_cache: List[
|
12
|
-
|
11
|
+
class CUDAGraph(MultiGraphRunner):
|
12
|
+
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
13
|
+
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
13
14
|
|
14
|
-
|
15
|
-
|
16
|
-
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
|
17
|
-
self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
|
18
|
-
self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
|
19
|
-
self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
|
20
|
-
self.updatable_nodes: Dict[int, Tuple[Any, Any, Any]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params)
|
15
|
+
# Check all jit items are compatible.
|
16
|
+
if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
|
21
17
|
|
22
|
-
self.
|
23
|
-
|
18
|
+
self.jc_idx_with_updatable_rawbufs = dedup([x[0] for x in self.input_replace.keys()])
|
19
|
+
self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
|
20
|
+
|
21
|
+
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
|
24
22
|
|
25
|
-
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
|
26
23
|
for j,ji in enumerate(self.jit_cache):
|
27
|
-
|
24
|
+
if isinstance(ji.prg, CompiledRunner):
|
25
|
+
global_size, local_size = ji.prg.p.launch_dims(var_vals)
|
26
|
+
|
27
|
+
new_node = cuda.CUgraphNode()
|
28
|
+
deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None],
|
29
|
+
[x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
|
30
|
+
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
28
31
|
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params)
|
32
|
+
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
|
33
|
+
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
|
34
|
+
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
|
33
35
|
|
34
|
-
|
35
|
-
|
36
|
+
if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
|
37
|
+
self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
|
38
|
+
elif isinstance(ji.prg, BufferXfer):
|
39
|
+
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
40
|
+
src_dev = cast(CUDADevice, Device[src.device])
|
41
|
+
node_from = cuda.CUgraphNode()
|
42
|
+
deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
|
43
|
+
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
44
|
+
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
|
45
|
+
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
|
46
|
+
WidthInBytes=dest.nbytes, Height=1, Depth=1)
|
47
|
+
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
|
48
|
+
if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
|
36
49
|
|
37
|
-
self.instance =
|
50
|
+
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
|
38
51
|
|
39
|
-
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False
|
40
|
-
# Update rawbuffers in the
|
52
|
+
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
|
53
|
+
# Update rawbuffers in the c_args struct.
|
41
54
|
for (j,i),input_idx in self.input_replace.items():
|
42
|
-
setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
|
55
|
+
if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
|
56
|
+
else:
|
57
|
+
if i == 0: self.updatable_nodes[j][1].destDevice = input_rawbuffers[input_idx]._buf
|
58
|
+
elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
|
43
59
|
|
44
|
-
# Update var_vals in the
|
45
|
-
for j in self.
|
46
|
-
for i,v in enumerate(cast(
|
47
|
-
setattr(self.updatable_nodes[j][2], f'
|
60
|
+
# Update var_vals in the c_args struct.
|
61
|
+
for j in self.jc_idx_with_updatable_var_vals:
|
62
|
+
for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
|
63
|
+
setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
|
48
64
|
|
49
|
-
# Update launch dims in the
|
50
|
-
for j in self.
|
51
|
-
self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(
|
65
|
+
# Update launch dims in the kern_params struct.
|
66
|
+
for j in self.jc_idx_with_updatable_launch_dims:
|
67
|
+
self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
|
52
68
|
|
53
69
|
# Update graph nodes with the updated structs.
|
54
|
-
for node, c_node_params,
|
55
|
-
|
70
|
+
for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
|
71
|
+
if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
|
72
|
+
else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))
|
56
73
|
|
57
|
-
|
58
|
-
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) # noqa: E501
|
59
|
-
return et
|
74
|
+
return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)
|
60
75
|
|
61
76
|
def __del__(self):
|
62
|
-
check(cuda.cuGraphDestroy(self.graph))
|
63
|
-
check(cuda.cuGraphExecDestroy(self.instance))
|
64
|
-
|
65
|
-
def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0))
|
66
|
-
def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
|
67
|
-
def graph_instantiate(self, graph):
|
68
|
-
return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0)))
|
69
|
-
def graph_add_kernel_node(self, graph, c_deps, c_node_params):
|
70
|
-
return init_c_var(cuda.CUgraphNode(), lambda x: check(cuda.cuGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_node_params)))) # noqa: E501
|
71
|
-
def graph_launch(self, *args, wait=False): return cu_time_execution(lambda: check(cuda.cuGraphLaunch(*args)), enable=wait)
|
72
|
-
def graph_exec_kernel_node_set_params(self, *args): return check(cuda.cuGraphExecKernelNodeSetParams(*args))
|
73
|
-
def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config):
|
74
|
-
return cuda.CUDA_KERNEL_NODE_PARAMS(prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config)
|
77
|
+
if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
|
78
|
+
if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
|
79
|
+
|
75
80
|
def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
|
76
81
|
node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
|
@@ -0,0 +1,187 @@
|
|
1
|
+
import collections, array, time
|
2
|
+
from typing import List, Any, Dict, cast, Optional, Tuple, Set
|
3
|
+
from tinygrad.helpers import round_up, to_mv, PROFILE
|
4
|
+
from tinygrad.device import Buffer, BufferOptions, Compiled, Device
|
5
|
+
from tinygrad.shape.symbolic import Variable
|
6
|
+
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
7
|
+
from tinygrad.engine.jit import MultiGraphRunner
|
8
|
+
|
9
|
+
class HCQGraph(MultiGraphRunner):
  """Graph runner that records a whole jit cache into per-device hardware compute and copy queues once, then replays them.

  __init__ writes kernel arguments into device-visible kernarg buffers and encodes every kernel launch / buffer copy,
  plus the signal waits between them, into `comp_queues` / `copy_queues`. __call__ patches input buffer addresses,
  variable values, launch dims and signal values in place, and resubmits the queues.
  """
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    # Every device owning any buffer of the jit cache. Built from a set, so list order is not deterministic;
    # devices[0] is used below only as the owner of signal objects.
    self.devices = list(set(cast(Any, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

    # Allocate kernel args.
    # One cpu-accessible kernarg buffer per device, with a 16-byte-rounded slot per kernel on that device.
    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
    for ji in self.jit_cache:
      if not isinstance(ji.prg, CompiledRunner): continue
      kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
    self.kernargs_bufs: Dict[Compiled, Any] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
    kernargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}

    # Fill initial arguments.
    self.kargs_addrs: Dict[int, int] = {}  # ji index -> start address of its kernarg slot
    self.ji_args_bufs: Dict[int, memoryview] = {}  # ji index -> view of its 8-byte buffer-address args
    self.ji_args_vars: Dict[int, memoryview] = {}  # ji index -> view of its 4-byte variable args
    for j,ji in enumerate(self.jit_cache):
      if not isinstance(ji.prg, CompiledRunner): continue
      self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
      kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)

      # Layout from kernargs_offset: len(bufs) 64-bit buffer addresses, then len(vars) 32-bit variable values.
      self.ji_args_bufs[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset, len(ji.bufs) * 8).cast('Q')
      self.ji_args_vars[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset + len(ji.bufs) * 8, len(ji.prg.p.vars) * 4).cast('I')
      for i in range(len(ji.bufs)): self.ji_args_bufs[j][i] = cast(Buffer, ji.bufs[i])._buf.va_addr
      for i in range(len(ji.prg.p.vars)): self.ji_args_vars[j][i] = var_vals[ji.prg.p.vars[i]]

      # NV needs constbuffer to be set
      if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)

    # Schedule Dependencies.
    # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
    # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
    # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device's
    # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
    self.comp_queues: Dict[Compiled, Any] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
    self.copy_queues: Dict[Compiled, Any] = {dev: dev.hw_copy_queue_t() for dev in self.devices}

    # Signal value j+1 on a queue's signal means "ji j on that queue has completed".
    self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, sigval, prof_info)]
    self.signals: Dict[Any, Any] = {q: self.devices[0]._get_signal(value=0) for q in list(self.comp_queues.values())+list(self.copy_queues.values())}
    self.dev_kickoff_signal = {dev: self.devices[0]._get_signal(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal]
    self.kickoff_value = 0

    # queue -> devices whose kickoff this queue is already (transitively) synchronized with.
    # NOTE(review): seeded with device objects here, but access_resources adds buf.device (a device-name string),
    # so the two never compare equal; a compute queue therefore also records a wait on its own kickoff signal — confirm intended.
    self.save_devs: Dict[Any, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
    for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)

    self.graph_timeline = {dev: 0 for dev in self.devices} # Dict[dev, last graph sigval]
    self.last_ji: Dict[Any, Any] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}

    for j,ji in enumerate(self.jit_cache):
      # Kernels run on their own device's compute queue; copies run on the source device's (bufs[1]) copy queue.
      enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
      enqueue_queue = self.comp_queues[enqueue_dev] if isinstance(ji.prg, CompiledRunner) else self.copy_queues[enqueue_dev]
      out_signal = self.signals[enqueue_queue]
      # The first `outcount` buffers of a kernel (the first buffer of a copy) are written; the rest are read.
      writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
      deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)

      if isinstance(ji.prg, CompiledRunner):
        # Update signal on compute kernel to depend on the previous kernel.
        if (last_j:=self.last_ji[enqueue_queue]) is not None: deps = [x for x in deps if id(x[0]) != id(out_signal)] + [(out_signal, last_j + 1)]

        # Remove self-dependency for AMD or NV with only 1 same-queue dep, since NV chains 2+ execs in this case, eliminating dep need.
        if (dname:=enqueue_dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(out_signal)):
          deps = [x for x in deps if id(x[0]) != id(out_signal)]
      elif isinstance(ji.prg, BufferXfer): deps = [x for x in deps if id(x[0]) != id(out_signal)]

      # Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
      for sig, val in deps:
        if id(sig) in [id(x) for x in self.signals.values()]:
          self.signal_sched[val - 1] = self.signal_sched[val - 1][:1] + (val,) + self.signal_sched[val - 1][2:]

      prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
      # prof_info = [start timestamp signal, end timestamp signal, device, description, is_copy]
      prof_info = ([enqueue_dev._get_signal() for _ in range(2)] + [enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None
      # Kernels start without an output signal value (enabled later only if some ji depends on them); copies always signal j+1.
      self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
      self.last_ji[enqueue_queue] = j

    # Build hardware queues.
    self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}  # ji index -> (queue, index of its exec command), for launch-dim updates
    self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}  # dest device -> source devices copying into it
    self.kickoff_wait_cmds: Dict[Any, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}

    # Compute-queue prologue: barrier, wait for prior work on this device's timeline, wait for the CPU kickoff,
    # then signal this device's kickoff. __call__ patches these commands by index (1, 2, 3) on every run,
    # presumably because each queue method appends exactly one command (len(queue) - 1 is used as "last command" below).
    for dev in self.devices:
      self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
                           .wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev], self.kickoff_value)

    for j,ji in enumerate(self.jit_cache):
      deps, signal_value, prof_info = self.signal_sched[j]
      enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore

      # Encode waits and start profile timestamp (if needed).
      for sig, val in deps:
        enqueue_queue.wait(sig, val)
        # remember kickoff waits so their value can be bumped on every __call__
        if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
      if prof_info: enqueue_queue.timestamp(prof_info[0])

      # Encode main commands based on ji type.
      if isinstance(ji.prg, CompiledRunner):
        enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
                           signal=self.signals[enqueue_queue] if signal_value is not None else None, signal_value=signal_value)
        self.exec_ptrs[j] = (enqueue_queue, len(enqueue_queue) - 1)
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        Device[src.device]._gpu_map(dest._buf) #type: ignore
        # NOTE(review): copy addresses are baked in here and never patched in __call__ — confirm inputs never feed a BufferXfer.
        enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes).signal(self.signals[enqueue_queue], signal_value)
        self.copy_to_devs[Device[dest.device]].add(Device[src.device])

      # Encode finish profile timestamp (if needed).
      if prof_info: enqueue_queue.timestamp(prof_info[1])

    for dev in self.devices:
      # Before signalling completion, each compute queue waits for the last copy into this device and for its own copy queue.
      for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
        if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue
        self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][1])

      # Final command of every compute queue: advance the device timeline (patched as index len-1 in __call__).
      self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
      if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
      if hasattr(self.copy_queues[dev], 'bind') and self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
    """Replay the recorded queues with new inputs and variable values.

    Returns the seconds spent waiting for all devices to finish when `wait` is True, otherwise None.
    """
    # Wait and restore signals
    self.kickoff_value += 1
    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
    for queue in self.comp_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
    for queue in self.copy_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
    self.devices[0]._set_signal(self.dev_kickoff_signal['CPU'], self.kickoff_value)

    # Timestamps from the previous run are complete now (its timeline signal was awaited above).
    if PROFILE and self.kickoff_value > 1:
      for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
        dev.raw_prof_records += [(dev._read_timestamp(st), dev._read_timestamp(en), desc, is_cp)]

    # Update rawbuffers
    # NOTE(review): ji_args_bufs only has entries for CompiledRunner items; this assumes input_replace never points at a BufferXfer.
    for (j,i),input_idx in self.input_replace.items(): self.ji_args_bufs[j][i] = input_rawbuffers[input_idx]._buf.va_addr

    # Update var_vals
    for j in self.jc_idx_with_updatable_var_vals:
      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars): self.ji_args_vars[j][i] = var_vals[v]

    # Update launch dims of already-encoded exec commands.
    for j in self.jc_idx_with_updatable_launch_dims:
      queue, cmd_ptr = self.exec_ptrs[j]
      queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))

    for dev in self.devices:
      # Patch the prologue (timeline wait, CPU kickoff wait, device kickoff signal) and the final timeline signal, then submit.
      # NOTE(review): kickoff waits recorded in kickoff_wait_cmds for compute queues are not re-valued here (only copy queues are);
      # they keep the value from __init__ — confirm compute queues never need a cross-device kickoff wait.
      self.comp_queues[dev].update_wait(1, dev.timeline_signal, dev.timeline_value - 1).update_wait(2, value=self.kickoff_value) \
                           .update_signal(3, value=self.kickoff_value) \
                           .update_signal(len(self.comp_queues[dev]) - 1, dev.timeline_signal, dev.timeline_value).submit(dev)

      if self.last_ji[(cp_queue:=self.copy_queues[dev])] is not None:
        for cmd_idx in self.kickoff_wait_cmds[cp_queue]: cp_queue.update_wait(cmd_idx, value=self.kickoff_value)
        cp_queue.submit(dev)

      self.graph_timeline[dev] = dev.timeline_value
      dev.timeline_value += 1

    if wait:
      st = time.perf_counter()
      for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
      return time.perf_counter() - st
    return None

  def access_resources(self, queue, read, write, new_val):
    """Return the (signal, value) pairs `queue` must wait on before reading `read` and writing `write`.

    (queue, new_val) is handed to MultiGraphRunner._access_resources, presumably to record it as the new dependency for
    these buffers. Queue dependencies are collapsed to one wait per signal with the maximum required value. A wait on a
    device's kickoff signal is appended the first time this queue touches a buffer on that device.
    """
    deps = self._access_resources(read, write, (queue, new_val))

    sync_signals = []
    # a queue we depend on has already synchronized with its devices' kickoffs, so inherit that knowledge
    for dep_queue,_ in deps: self.save_devs[queue].update(self.save_devs[dep_queue])
    for buf in read+write:
      if buf.device not in self.save_devs[queue]:
        self.save_devs[queue].add(buf.device)
        sync_signals += [(self.dev_kickoff_signal[Device[buf.device]], self.kickoff_value)]

    return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals

  def __del__(self):
    """Wait for the last submitted run, then return signals and kernarg buffers to their devices."""
    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])

    # Graph is destructed. No need to keep signals any more, so return them as part of profiling.
    if PROFILE and self.kickoff_value > 1:
      for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore

    self.devices[0].signals_pool += list(self.dev_kickoff_signal.values()) + list(self.signals.values()) # type: ignore
    for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))
|
tinygrad/runtime/graph/metal.py
CHANGED
@@ -1,22 +1,17 @@
|
|
1
1
|
from typing import List, Any, Dict, cast, Optional
|
2
|
-
import numpy as np
|
3
2
|
import Metal
|
4
3
|
from tinygrad.dtype import dtypes
|
5
|
-
from tinygrad.helpers import dedup, unwrap2
|
6
|
-
from tinygrad.device import Buffer
|
7
|
-
from tinygrad.
|
4
|
+
from tinygrad.helpers import dedup, unwrap2, GraphException
|
5
|
+
from tinygrad.device import Buffer
|
6
|
+
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
7
|
+
from tinygrad.engine.jit import GraphRunner
|
8
8
|
from tinygrad.shape.symbolic import Variable
|
9
|
-
from tinygrad.runtime.ops_metal import
|
9
|
+
from tinygrad.runtime.ops_metal import wait_check
|
10
10
|
|
11
|
-
class MetalGraph:
|
12
|
-
def __init__(self,
|
13
|
-
|
14
|
-
|
15
|
-
self.jit_cache = jit_cache
|
16
|
-
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
|
17
|
-
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
|
18
|
-
self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
|
19
|
-
self.device: MetalDevice = device
|
11
|
+
class MetalGraph(GraphRunner):
|
12
|
+
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
13
|
+
super().__init__(jit_cache, input_rawbuffers, var_vals)
|
14
|
+
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
|
20
15
|
|
21
16
|
# create metal batch exec
|
22
17
|
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
|
@@ -24,56 +19,57 @@ class MetalGraph:
|
|
24
19
|
icb_descriptor.setInheritBuffers_(False)
|
25
20
|
icb_descriptor.setInheritPipelineState_(False)
|
26
21
|
icb_descriptor.setMaxKernelBufferBindCount_(31)
|
27
|
-
self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
|
22
|
+
self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
|
23
|
+
Metal.MTLResourceOptions(0))
|
28
24
|
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
|
29
25
|
|
30
|
-
if len(
|
31
|
-
all_resources = [self.int_buf] if len(
|
26
|
+
if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
|
27
|
+
all_resources = [self.int_buf] if len(self.vars) else []
|
28
|
+
|
32
29
|
for j,ji in enumerate(self.jit_cache):
|
33
|
-
prg:
|
30
|
+
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
|
34
31
|
descriptor = Metal.MTLComputePipelineDescriptor.new()
|
35
32
|
descriptor.setComputeFunction_(prg.clprg.fxn)
|
36
33
|
descriptor.setSupportIndirectCommandBuffers_(True)
|
37
|
-
pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) # noqa: E501
|
38
34
|
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
|
39
|
-
icb_command.setComputePipelineState_(
|
40
|
-
|
35
|
+
icb_command.setComputePipelineState_(unwrap2(
|
36
|
+
self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)))
|
37
|
+
for i,b in enumerate(ji.bufs):
|
41
38
|
if b is not None:
|
42
39
|
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
|
43
40
|
all_resources.append(b._buf)
|
44
|
-
|
45
|
-
for i,v in enumerate(prg.vars):
|
46
|
-
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
|
41
|
+
for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, self.vars.index(v)*4, len(ji.bufs)+i)
|
47
42
|
if j not in self.jc_idx_with_updatable_launch_dims:
|
48
|
-
global_size, local_size = prg.launch_dims(var_vals)
|
43
|
+
global_size, local_size = prg.p.launch_dims(var_vals)
|
49
44
|
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
50
45
|
icb_command.setBarrier()
|
46
|
+
|
51
47
|
self.all_resources = dedup(all_resources)
|
52
48
|
self.command_buffer: Any = None
|
53
|
-
if len(
|
49
|
+
if len(self.vars): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
|
50
|
+
|
51
|
+
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
|
52
|
+
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
53
|
+
all_resources = dedup(self.all_resources + [x._buf for x in input_rawbuffers])
|
54
54
|
|
55
|
-
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
|
56
|
-
# NOTE: you at least can't update the ints if this is running
|
57
|
-
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
|
58
|
-
all_resources = self.all_resources + [x._buf for x in input_rawbuffers]
|
59
55
|
for (j,i),input_idx in self.input_replace.items():
|
60
56
|
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
|
61
57
|
for j in self.jc_idx_with_updatable_launch_dims:
|
62
|
-
global_size, local_size = cast(
|
63
|
-
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
|
64
|
-
|
58
|
+
global_size, local_size = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
|
59
|
+
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
|
60
|
+
Metal.MTLSize(*local_size))
|
61
|
+
for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
|
62
|
+
|
65
63
|
command_buffer = self.device.mtl_queue.commandBuffer()
|
66
64
|
encoder = command_buffer.computeCommandEncoder()
|
67
65
|
encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
|
68
|
-
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
|
66
|
+
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0, len(self.jit_cache)))
|
69
67
|
encoder.endEncoding()
|
70
68
|
command_buffer.commit()
|
71
69
|
self.command_buffer = command_buffer
|
70
|
+
|
72
71
|
if wait:
|
73
|
-
command_buffer
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
et = None
|
78
|
-
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) # noqa: E501
|
79
|
-
return et
|
72
|
+
wait_check(command_buffer)
|
73
|
+
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
74
|
+
self.device.mtl_buffers_in_flight.append(command_buffer)
|
75
|
+
return None
|