tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/graph.py CHANGED
@@ -1,10 +1,10 @@
  import os, atexit, functools, contextlib
  from collections import defaultdict
- from typing import List, Any, DefaultDict, Union
- from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
+ from typing import List, Any, DefaultDict
+ from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps
  from tinygrad.device import Device
- from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
- from tinygrad.codegen.uops import UOps, UOp, UPat
+ from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters
+ from tinygrad.codegen.uops import UOps, UOp
  from tinygrad.shape.symbolic import NumNode
  from tinygrad.lazy import LazyBuffer

@@ -12,12 +12,11 @@ with contextlib.suppress(ImportError): import networkx as nx

  # **** debugging and graphing ****

- if DEBUG >= 2:
-   def print_globalcounters():
-     if GlobalCounters.time_sum_s == 0: return
-     print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
-           f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
-   atexit.register(print_globalcounters)
+ def print_globalcounters():
+   if GlobalCounters.time_sum_s == 0: return
+   print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
+         f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
+ if DEBUG >= 2: atexit.register(print_globalcounters)

  def save_graph(G, fn, opt=""):
    print("saving", G, f"to {fn}.svg")
@@ -44,11 +43,10 @@ def realized_lazybuffer(lb:'LazyBuffer', num):
    G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
    G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num}"'

- top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
-               TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
+ top_colors = {MetaOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0", TernaryOps: "#c0c0c0"}
  def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
    init_graph()
-   if lb.base.realized is None and lb.base.op is LoadOps.CONST: return
+   if lb.base.realized is None and lb.base.op is MetaOps.CONST: return
    if lb.base != lb:
      offset = lb.st.expr_idxs([NumNode(0)] * len(lb.st.shape))[0]
      label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
@@ -59,14 +57,14 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
      label_append = []
      for idx,x in enumerate(lb.srcs):
        if nm(x) not in G.nodes: log_lazybuffer(x)
-       if x.base.realized is None and x.base.op is LoadOps.CONST:
+       if x.base.realized is None and x.base.op is MetaOps.CONST:
          label_append.append(f"\nCONST{idx} {x.base.arg:g}")
        else:
          G.add_edge(nm(x), nm(lb), color='#a0a0a0')
      label = '"' + \
        (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
-       (f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {LoadOps.CONST, UnaryOps.CAST} else "") + \
-       (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + '"'
+       (f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {MetaOps.CONST, UnaryOps.CAST} else "") + \
+       (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + f'\n{lb.metadata}"'
      G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
      if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
    else:
@@ -74,27 +72,16 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
        # realized but unseen?
        G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")

- def _tree(dag:Union[LazyOp, UOp, UPat], cycles, cnt):
-   cnt[0] += 1
-   src = dag.src if isinstance(dag.src, (list, tuple)) else [] if dag.src is None else [dag.src]
-   if len(src) == 0: return [f"━━ {dag.op} {dag.arg}"]
-   if (lid := id(dag)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
-     return [f"━⬆︎ goto {cycles[id(dag)][0]}: {dag.op}"]
-   cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
-   lines = [f"━┳ {dag.op} {dag.arg}"]
-   childs = [_tree(c, cycles, cnt) for c in src]
-   for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
-   return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
-
- def print_tree(dag:Union[LazyOp, UOp, UPat]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(dag, {}, [-1]))]))
-
+ graph_uops_cnt = 0
  def graph_uops(uops:List[UOp]):
+   global graph_uops_cnt
    colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
-             UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
+             UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
              UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
    G = nx.DiGraph()
    for u in uops:
      if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
      G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg).replace(':', '')) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
      for v in u.src: G.add_edge(uops.index(v), uops.index(u))
-   save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
+   save_graph(G, f'{GRAPHPATH}.{graph_uops_cnt}.uops', '-Grankdir=LR')
+   graph_uops_cnt += 1
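
Note on the graph.py hunks above: 0.9.2 renames LoadOps to MetaOps (and drops BufferOps from the node-coloring table), and graph_uops now writes a numbered {GRAPHPATH}.<n>.uops file per call instead of overwriting a single {GRAPHPATH}.uops. A minimal migration sketch for downstream code that groups lazybuffer ops the same way this file does; the is_const_meta helper is made up for illustration:

# tinygrad <= 0.9.1
# from tinygrad.ops import LoadOps
# def is_const_meta(lb): return lb.base.op is LoadOps.CONST

# tinygrad 0.9.2
from tinygrad.ops import MetaOps
def is_const_meta(lb): return lb.base.op is MetaOps.CONST
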
tinygrad/engine/jit.py CHANGED
@@ -3,16 +3,18 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
  import functools, itertools, collections
  from tinygrad.tensor import Tensor
  from tinygrad.lazy import LazyBuffer
- from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, ContextVar, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
+ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, colored, JIT, dedup
  from tinygrad.device import Buffer, Compiled, Device
  from tinygrad.dtype import DType
  from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.shape.symbolic import Variable, sint
- from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
- from tinygrad.engine.schedule import _internal_memory_planner
+ from tinygrad.shape.symbolic import Variable, sint, sym_infer
+ from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner, _internal_memory_planner
  from tinygrad.nn.state import get_parameters
+ from dataclasses import dataclass
  from weakref import WeakKeyDictionary

+ class GraphException(Exception): pass
+
  def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
    # Split JIT cache into batches for faster graph execution.
    # This allows the accelerator to run some batches while subsequent graphs are still being updated.
@@ -30,10 +32,10 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
        for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
        graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
        max_batch_size *= 2
-       if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
+       if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
      except GraphException as e:
        graphed_jit_cache.extend(current_batch)
-       if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
+       if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
      current_batch = []
      current_device = None

@@ -47,7 +49,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
      graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
      can_be_graphed = ji_graph_dev and ji_graph_dev.graph
      can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
-                        type(ji_graph_dev) == type(current_device))
+                        type(ji_graph_dev) is type(current_device))
      can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph
      if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()

@@ -70,20 +72,40 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
  class GraphRunner(Runner): # pylint: disable=abstract-method
    def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
      self.jit_cache = jit_cache
-     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-     self.jc_idx_with_updatable_launch_dims = []
-     self.jc_idx_with_updatable_var_vals = []
+     self.input_replace:Dict[Tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
+     self.var_vals_replace:Dict[int, List[int]] = {}
+     self.launch_dims_replace:Dict[int, Tuple[Optional[int], Optional[int]]] = {}
+
      op_estimate: sint = 0
      mem_estimate: sint = 0
+     lds_estimate: sint = 0
+
+     self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
+     self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and not all_int(d)] +
+                                [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and not all_int(d)])
+     def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
+
      for j,ji in enumerate(jit_cache):
        op_estimate += ji.prg.op_estimate
        mem_estimate += ji.prg.mem_estimate
+       lds_estimate += ji.prg.lds_estimate
        if isinstance(ji.prg, CompiledRunner):
-         if ji.prg.p.vars: self.jc_idx_with_updatable_var_vals.append(j)
-         if (ji.prg.p.global_size and not all_int(ji.prg.p.global_size)) or (ji.prg.p.local_size and not all_int(ji.prg.p.local_size)):
-           self.jc_idx_with_updatable_launch_dims.append(j)
-     self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
-     super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)
+         if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]
+
+         global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
+         if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+
+     super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
+                      op_estimate, mem_estimate, lds_estimate)
+
+   def updated_vars(self, var_vals):
+     vals = [var_vals[v] for v in self.vars]
+     for j, vidxs in self.var_vals_replace.items():
+       for i, v in enumerate(vidxs): yield j, i, vals[v]
+
+   def updated_launch_dims(self, var_vals):
+     dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
+     for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)

  class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
    def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
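
The GraphRunner changes above replace the old jc_idx_with_updatable_var_vals / jc_idx_with_updatable_launch_dims index lists with per-kernel replacement tables, exposed through the updated_vars and updated_launch_dims generators: the first yields (jit-cache index, argument slot, value) triples, the second yields (jit-cache index, global dims, local dims) with None for dimensions that need no update. A hedged sketch of how a backend graph runner might consume them on each run; the _set_kernel_arg and _set_launch_dims calls are hypothetical placeholders for a real backend's update API:

from typing import Dict
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.jit import GraphRunner

class ExampleBackendGraph(GraphRunner):  # hypothetical subclass, for illustration only
  def update(self, var_vals: Dict[Variable, int]):
    # patch integer kernel arguments whose values come from bound Variables
    for j, arg_slot, value in self.updated_vars(var_vals):
      self._set_kernel_arg(j, arg_slot, value)            # hypothetical backend call
    # patch launch dimensions that were symbolic at capture time
    for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
      self._set_launch_dims(j, global_dims, local_dims)   # hypothetical backend call
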
@@ -106,93 +128,149 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
      return list({id(x):x for x in wait_nodes}.values())

  ReturnType = TypeVar('ReturnType')
- IN_JIT = ContextVar('IN_JIT', 0)
+ @dataclass
+ class CapturedJit(Generic[ReturnType]):
+   ret: Any # includes the Tensors or any other returned object
+   jit_cache: List[ExecItem]
+   input_replace: Dict[Tuple[int, int], int]
+   extra_view_inputs: List[Tuple[int, int, str, int, DType]]
+   expected_names: List[Union[int, str]]
+   expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]]
+
+   def __reduce__(self):
+     return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
+                             self.expected_names, self.expected_st_vars_dtype_device)
+
+   def __post_init__(self):
+     self._jit_cache: List[ExecItem] = self.jit_cache
+     self._input_replace: Dict[Tuple[int, int], int] = self.input_replace
+     self._graphed = False
+     self._clear_inputs()
+
+   def _clear_inputs(self):
+     for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
+
+   # jit exec
+   def __call__(self, input_buffers:List[Buffer], var_vals:Dict[Variable, int]) -> ReturnType:
+     # assign inputs
+     for idx, offset, device, size, dtype in self.extra_view_inputs:
+       input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
+     for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
+
+     # Condense the items into a graph executor.
+     if JIT < 2 and not self._graphed:
+       self._jit_cache = apply_graph_to_jit(self._jit_cache, input_buffers, var_vals)
+       self._input_replace = get_input_replace(self._jit_cache, input_buffers)
+       self._graphed = True
+
+     if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
+     for ei in self._jit_cache: ei.run(var_vals, jit=True)
+     self._clear_inputs()
+     return self.ret
+
+ def _prepare_jit_inputs(args, kwargs):
+   input_tensors: List[Tuple[Union[int, str], Tensor]] = \
+     [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
+   if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
+   names: List[Union[int, str]] = [name for name,_ in input_tensors]
+   lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
+   st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
+   input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
+   assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
+   var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
+     [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
+   st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
+   return input_buffers, var_vals, names, st_vars_dtype_device
+
  class TinyJit(Generic[ReturnType]):
-   def __init__(self, fxn:Callable[..., ReturnType]):
+   def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None):
+     assert fxn or captured, "need either a function or a CapturedJit"
      self.fxn = fxn
-     self.reset()
+     self.captured: Optional[CapturedJit] = captured
+     self.cnt: int = 2 if self.fxn is None else 0

    def add_buffer(self, b:Buffer) -> Buffer:
-     if found:=self.buffer_replace.get(b, None): return found
+     if found:=self._buffer_replace.get(b, None): return found
      if b.is_allocated() or b.lb_refcount > 0: return b
      if b._base is not None:
-       self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.buffer_replace.get(b._base, b._base), offset=b.offset)
+       self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
      else:
-       self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
+       self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
      return ret

    def add(self, ei:ExecItem):
-     self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
+     self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))

    def reset(self):
-     self.jit_cache: List[ExecItem] = []
-     self.input_replace: Dict[Tuple[int, int], int] = {}
-     self.extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
-     self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
-     self.cnt: int = 0
+     assert self.fxn is not None, "can't reset without function"
+     self.cnt = 0
+     self.captured = None
+
+   def __reduce__(self):
+     assert self.captured is not None, "can't pickle an uncaptured JIT"
+     return self.__class__, (None, self.captured)
+
+   # keep legacy code working
+   @property
+   def jit_cache(self) -> List[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
+   @property
+   def input_replace(self) -> Dict[Tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

    def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods

    def __call__(self, *args, **kwargs) -> ReturnType:
-     input_tensors: List[Tuple[Union[int, str], Tensor]] = \
-       [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
-     if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
-     names: List[Union[int, str]] = [name for name,_ in input_tensors]
-     lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
-     st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
-     input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
-     assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
-     var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
-       [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
-     st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
+     input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
      if not JIT or self.cnt == 0:
-       if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
        # jit ignore
-       with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, IN_JIT=1):
-         self.ret = self.fxn(*args, **kwargs)
-         if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
+       assert self.fxn is not None
+       with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
+         ret = self.fxn(*args, **kwargs)
+         if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
      elif self.cnt == 1:
        # jit capture
-       self.expected_names: List[Union[int, str]] = names
-       self.expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = st_vars_dtype_device
+       assert self.fxn is not None
+       if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
+       self._jit_cache: List[ExecItem] = []
+       self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
        with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
          capturing.append(self)
-         self.ret = self.fxn(*args, **kwargs)
-         if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
-         capturing.clear()
-       del self.buffer_replace
-       assert len(self.jit_cache), "didn't JIT anything!"
-       if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_buffers)} inputs")
+         try:
+           ret = self.fxn(*args, **kwargs)
+           if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
+         except Exception as e: raise e
+         finally: capturing.clear()
+       jit_cache = self._jit_cache
+       del self._buffer_replace, self._jit_cache
+       assert len(jit_cache), "didn't JIT anything!"
+       if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")

        # track inputs that are views of buffers
-       for item in self.jit_cache:
+       # TODO: eventually expected_buffers should live in ExecItem
+       extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
+       for item in jit_cache:
          for b in item.bufs:
            if b is not None and b._base is not None and b._base in input_buffers:
              input_buffers.append(b)
-             self.extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
+             extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))

        # memory planning (optional)
-       assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in self.jit_cache], debug_prefix="JIT ")
-       self.jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in self.jit_cache]
+       # Exclude buffers involved in transfer ops to preserve parallelism.
+       noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
+       assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
+       jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]

-       # Condense the items into a graph executor.
-       if JIT < 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals)
+       input_replace = get_input_replace(jit_cache, input_buffers)
+       if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

-       self.input_replace = get_input_replace(self.jit_cache, input_buffers)
-       if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
+       # set this for next run
+       self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
      elif self.cnt >= 2:
        # jit exec
-       assert self.expected_names == names, f"args mismatch in JIT: {self.expected_names=} != {names}"
-       assert self.expected_st_vars_dtype_device == st_vars_dtype_device, \
-         f"args mismatch in JIT: {self.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
-       for idx, offset, device, size, dtype in self.extra_view_inputs:
-         input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
-       for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_buffers[input_idx]
-       if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
-       for ei in self.jit_cache: ei.run(var_vals, jit=True)
-
-       # clear jit inputs
-       for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
+       assert self.captured is not None
+       assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
+       assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
+         f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
+       ret = self.captured(input_buffers, var_vals)

      self.cnt += 1
-     return self.ret
+     return ret
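
The TinyJit rewrite above moves the captured state into the new CapturedJit dataclass and, via the two __reduce__ methods, makes a captured jit picklable; constructing TinyJit(None, captured=...) starts cnt at 2, so a reloaded jit goes straight to exec. A usage sketch, not taken from the diff: the toy step function is made up, and pickling assumes a working tinygrad device whose compiled artifacts are themselves picklable.

import pickle
from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x.T).relu().realize()

for _ in range(3): step(Tensor.rand(4, 4))  # call 0: jit ignore, call 1: capture, call 2+: exec

blob = pickle.dumps(step)       # only valid once self.captured is set (after capture)
restored = pickle.loads(blob)   # rebuilt as TinyJit(None, captured=...), cnt starts at 2
print(restored(Tensor.rand(4, 4)).numpy())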