tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/graph.py
CHANGED
@@ -1,10 +1,10 @@
 import os, atexit, functools, contextlib
 from collections import defaultdict
-from typing import List, Any, DefaultDict
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps,
+from typing import List, Any, DefaultDict
+from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps
 from tinygrad.device import Device
-from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters
-from tinygrad.codegen.uops import UOps, UOp
+from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters
+from tinygrad.codegen.uops import UOps, UOp
 from tinygrad.shape.symbolic import NumNode
 from tinygrad.lazy import LazyBuffer

@@ -12,12 +12,11 @@ with contextlib.suppress(ImportError): import networkx as nx

 # **** debugging and graphing ****

-
-
-
-
-
-  atexit.register(print_globalcounters)
+def print_globalcounters():
+  if GlobalCounters.time_sum_s == 0: return
+  print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
+        f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
+if DEBUG >= 2: atexit.register(print_globalcounters)

 def save_graph(G, fn, opt=""):
   print("saving", G, f"to {fn}.svg")

@@ -44,11 +43,10 @@ def realized_lazybuffer(lb:'LazyBuffer', num):
   G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
   G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num}"'

-top_colors = {
-              TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
+top_colors = {MetaOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0", TernaryOps: "#c0c0c0"}
 def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
   init_graph()
-  if lb.base.realized is None and lb.base.op is
+  if lb.base.realized is None and lb.base.op is MetaOps.CONST: return
   if lb.base != lb:
     offset = lb.st.expr_idxs([NumNode(0)] * len(lb.st.shape))[0]
     label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")

@@ -59,14 +57,14 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
     label_append = []
     for idx,x in enumerate(lb.srcs):
       if nm(x) not in G.nodes: log_lazybuffer(x)
-      if x.base.realized is None and x.base.op is
+      if x.base.realized is None and x.base.op is MetaOps.CONST:
         label_append.append(f"\nCONST{idx} {x.base.arg:g}")
       else:
         G.add_edge(nm(x), nm(lb), color='#a0a0a0')
     label = '"' + \
       (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
-      (f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {
-      (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + '"'
+      (f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {MetaOps.CONST, UnaryOps.CAST} else "") + \
+      (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + f'\n{lb.metadata}"'
     G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
     if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
   else:

@@ -74,27 +72,16 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
       # realized but unseen?
      G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")

-
-  cnt[0] += 1
-  src = dag.src if isinstance(dag.src, (list, tuple)) else [] if dag.src is None else [dag.src]
-  if len(src) == 0: return [f"━━ {dag.op} {dag.arg}"]
-  if (lid := id(dag)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
-    return [f"━⬆︎ goto {cycles[id(dag)][0]}: {dag.op}"]
-  cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
-  lines = [f"━┳ {dag.op} {dag.arg}"]
-  childs = [_tree(c, cycles, cnt) for c in src]
-  for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
-  return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
-
-def print_tree(dag:Union[LazyOp, UOp, UPat]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(dag, {}, [-1]))]))
-
+graph_uops_cnt = 0
 def graph_uops(uops:List[UOp]):
+  global graph_uops_cnt
   colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
-            UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
+            UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
             UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
   G = nx.DiGraph()
   for u in uops:
     if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
     G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg).replace(':', '')) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
     for v in u.src: G.add_edge(uops.index(v), uops.index(u))
-  save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
+  save_graph(G, f'{GRAPHPATH}.{graph_uops_cnt}.uops', '-Grankdir=LR')
+  graph_uops_cnt += 1
tinygrad/engine/jit.py
CHANGED
@@ -3,16 +3,18 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
 import functools, itertools, collections
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
-from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context,
+from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, colored, JIT, dedup
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, sint
-from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
-from tinygrad.engine.schedule import _internal_memory_planner
+from tinygrad.shape.symbolic import Variable, sint, sym_infer
+from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner, _internal_memory_planner
 from tinygrad.nn.state import get_parameters
+from dataclasses import dataclass
 from weakref import WeakKeyDictionary

+class GraphException(Exception): pass
+
 def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.

@@ -30,10 +32,10 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
       for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
       graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
       max_batch_size *= 2
-      if DEBUG >= 2: print(f"
+      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
     except GraphException as e:
       graphed_jit_cache.extend(current_batch)
-      if DEBUG >= 2: print(f"
+      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
     current_batch = []
     current_device = None

@@ -47,7 +49,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
     graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
     can_be_graphed = ji_graph_dev and ji_graph_dev.graph
     can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
-                       type(ji_graph_dev)
+                       type(ji_graph_dev) is type(current_device))
     can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
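
The flush_batch path shown above batches consecutive ExecItems per device, doubles max_batch_size after each successfully graphed batch, and falls back to running the items ungraphed when the backend raises GraphException. A standalone sketch of that control flow (generic Python, not tinygrad's classes; fuse and its size limit are invented for illustration):

# Standalone sketch of the batch-then-fallback flow in apply_graph_to_jit above.
class GraphException(Exception): pass

def fuse(batch):  # stand-in for current_device.graph(current_batch, ...)
  if len(batch) > 4: raise GraphException("too many kernels for this sketch")
  return ("graphed", list(batch))

def batch_items(items, max_batch_size=4):
  out, current = [], []
  def flush():
    nonlocal current, max_batch_size
    if not current: return
    try:
      out.append(fuse(current))
      max_batch_size *= 2          # a successful graph lets later batches grow
    except GraphException:
      out.extend(current)          # fall back: run these items ungraphed
    current = []
  for it in items:
    if len(current) >= max_batch_size: flush()
    current.append(it)
  flush()
  return out

print(batch_items(list(range(20))))  # first batch graphs, oversized later batches fall back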
@@ -70,20 +72,40 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
 class GraphRunner(Runner): # pylint: disable=abstract-method
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     self.jit_cache = jit_cache
-    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.
-    self.
+    self.input_replace:Dict[Tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
+    self.var_vals_replace:Dict[int, List[int]] = {}
+    self.launch_dims_replace:Dict[int, Tuple[Optional[int], Optional[int]]] = {}
+
     op_estimate: sint = 0
     mem_estimate: sint = 0
+    lds_estimate: sint = 0
+
+    self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
+    self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and not all_int(d)] +
+                               [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and not all_int(d)])
+    def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
+
     for j,ji in enumerate(jit_cache):
       op_estimate += ji.prg.op_estimate
       mem_estimate += ji.prg.mem_estimate
+      lds_estimate += ji.prg.lds_estimate
       if isinstance(ji.prg, CompiledRunner):
-        if ji.prg.p.vars: self.
-
-
-
-
+        if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]
+
+        global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
+        if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
+                     op_estimate, mem_estimate, lds_estimate)
+
+  def updated_vars(self, var_vals):
+    vals = [var_vals[v] for v in self.vars]
+    for j, vidxs in self.var_vals_replace.items():
+      for i, v in enumerate(vidxs): yield j, i, vals[v]
+
+  def updated_launch_dims(self, var_vals):
+    dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
+    for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)

 class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
@@ -106,93 +128,149 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
     return list({id(x):x for x in wait_nodes}.values())

 ReturnType = TypeVar('ReturnType')
-
+@dataclass
+class CapturedJit(Generic[ReturnType]):
+  ret: Any # includes the Tensors or any other returned object
+  jit_cache: List[ExecItem]
+  input_replace: Dict[Tuple[int, int], int]
+  extra_view_inputs: List[Tuple[int, int, str, int, DType]]
+  expected_names: List[Union[int, str]]
+  expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]]
+
+  def __reduce__(self):
+    return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
+                            self.expected_names, self.expected_st_vars_dtype_device)
+
+  def __post_init__(self):
+    self._jit_cache: List[ExecItem] = self.jit_cache
+    self._input_replace: Dict[Tuple[int, int], int] = self.input_replace
+    self._graphed = False
+    self._clear_inputs()
+
+  def _clear_inputs(self):
+    for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
+
+  # jit exec
+  def __call__(self, input_buffers:List[Buffer], var_vals:Dict[Variable, int]) -> ReturnType:
+    # assign inputs
+    for idx, offset, device, size, dtype in self.extra_view_inputs:
+      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
+    for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
+
+    # Condense the items into a graph executor.
+    if JIT < 2 and not self._graphed:
+      self._jit_cache = apply_graph_to_jit(self._jit_cache, input_buffers, var_vals)
+      self._input_replace = get_input_replace(self._jit_cache, input_buffers)
+      self._graphed = True
+
+    if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
+    for ei in self._jit_cache: ei.run(var_vals, jit=True)
+    self._clear_inputs()
+    return self.ret
+
+def _prepare_jit_inputs(args, kwargs):
+  input_tensors: List[Tuple[Union[int, str], Tensor]] = \
+    [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
+  if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
+  names: List[Union[int, str]] = [name for name,_ in input_tensors]
+  lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
+  st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
+  input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
+  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
+  var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
+    [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
+  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
+  return input_buffers, var_vals, names, st_vars_dtype_device
+
 class TinyJit(Generic[ReturnType]):
-  def __init__(self, fxn:Callable[..., ReturnType]):
+  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None):
+    assert fxn or captured, "need either a function or a CapturedJit"
     self.fxn = fxn
-    self.
+    self.captured: Optional[CapturedJit] = captured
+    self.cnt: int = 2 if self.fxn is None else 0

   def add_buffer(self, b:Buffer) -> Buffer:
-    if found:=self.
+    if found:=self._buffer_replace.get(b, None): return found
     if b.is_allocated() or b.lb_refcount > 0: return b
     if b._base is not None:
-      self.
+      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
     else:
-      self.
+      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
     return ret

   def add(self, ei:ExecItem):
-    self.
+    self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))

   def reset(self):
-    self.
-    self.
-    self.
-
-
+    assert self.fxn is not None, "can't reset without function"
+    self.cnt = 0
+    self.captured = None
+
+  def __reduce__(self):
+    assert self.captured is not None, "can't pickle an uncaptured JIT"
+    return self.__class__, (None, self.captured)
+
+  # keep legacy code working
+  @property
+  def jit_cache(self) -> List[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
+  @property
+  def input_replace(self) -> Dict[Tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods

   def __call__(self, *args, **kwargs) -> ReturnType:
-
-      [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
-    if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
-    names: List[Union[int, str]] = [name for name,_ in input_tensors]
-    lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
-    st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
-    input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
-    assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
-    var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
-      [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
-    st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
+    input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
     if not JIT or self.cnt == 0:
-      if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
       # jit ignore
-
-
-
+      assert self.fxn is not None
+      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
+        ret = self.fxn(*args, **kwargs)
+        if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
     elif self.cnt == 1:
       # jit capture
-      self.
-
+      assert self.fxn is not None
+      if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
+      self._jit_cache: List[ExecItem] = []
+      self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
       with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
         capturing.append(self)
-
-
-
-
-
-
+        try:
+          ret = self.fxn(*args, **kwargs)
+          if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
+        except Exception as e: raise e
+        finally: capturing.clear()
+      jit_cache = self._jit_cache
+      del self._buffer_replace, self._jit_cache
+      assert len(jit_cache), "didn't JIT anything!"
+      if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")

       # track inputs that are views of buffers
-
+      # TODO: eventually expected_buffers should live in ExecItem
+      extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
+      for item in jit_cache:
         for b in item.bufs:
           if b is not None and b._base is not None and b._base in input_buffers:
             input_buffers.append(b)
-
+            extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))

       # memory planning (optional)
-
-
+      # Exclude buffers involved in transfer ops to preserve parallelism.
+      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
+      assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
+      jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]

-
-      if
+      input_replace = get_input_replace(jit_cache, input_buffers)
+      if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

-
-
+      # set this for next run
+      self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
     elif self.cnt >= 2:
       # jit exec
-      assert self.
-      assert self.
-
-
-
-      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_buffers[input_idx]
-      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
-      for ei in self.jit_cache: ei.run(var_vals, jit=True)
-
-      # clear jit inputs
-      for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
+      assert self.captured is not None
+      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
+      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
+        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
+      ret = self.captured(input_buffers, var_vals)

     self.cnt += 1
-    return
+    return ret
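
Taken together, the TinyJit changes above move all captured state into the CapturedJit dataclass and give both classes a __reduce__, so a warmed-up jit can be serialized and reloaded as TinyJit(None, captured) with cnt already at 2. A hedged usage sketch (the function and tensor shapes are illustrative; whether the round trip works for a given backend depends on its buffers and programs pickling cleanly):

import pickle
from tinygrad import Tensor, TinyJit

@TinyJit
def double(x: Tensor) -> Tensor:
  return (x * 2).contiguous().realize()

# cnt 0: run eagerly, cnt 1: capture, cnt >= 2: replay the captured kernels
for _ in range(3): double(Tensor.rand(4, 4))

blob = pickle.dumps(double)      # only valid after capture (see the assert in __reduce__)
double2 = pickle.loads(blob)     # rebuilt as TinyJit(None, captured)
print(double2(Tensor.rand(4, 4)).numpy())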