tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/dtype.py CHANGED
@@ -1,39 +1,39 @@
1
- from typing import NamedTuple, Final, Optional, ClassVar, Set, Tuple, Dict
2
- import numpy as np # TODO: remove numpy
1
+ from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
2
+ from dataclasses import dataclass
3
3
  import functools
4
+ from tinygrad.helpers import getenv
4
5
 
5
- # TODO: migrate this from NamedTuple -> dataclass
6
- class DType(NamedTuple):
6
+ ConstType = Union[float, int, bool]
7
+
8
+ @dataclass(frozen=True, order=True)
9
+ class DType:
7
10
  priority: int # this determines when things get upcasted
8
11
  itemsize: int
9
12
  name: str
10
- np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
11
- sz: int = 1
12
- def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
13
+ fmt: Optional[str]
14
+ count: int
15
+ def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}"
13
16
  def vec(self, sz:int):
14
- assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
15
- return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{sz}", None, sz)
16
- def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
17
+ assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
18
+ return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
19
+ def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
17
20
 
18
21
  # dependent typing?
22
+ @dataclass(frozen=True, repr=False)
19
23
  class ImageDType(DType):
20
- def __new__(cls, priority, itemsize, name, np, shape, base):
21
- return super().__new__(cls, priority, itemsize, name, np)
22
- def __init__(self, priority, itemsize, name, np, shape, base):
23
- self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape
24
- self.base: DType = base
25
- super().__init__()
24
+ shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
25
+ base: DType
26
26
  def scalar(self): return self.base
27
27
  def vec(self, sz:int): return self.base.vec(sz)
28
28
  def __repr__(self): return f"dtypes.{self.name}({self.shape})"
29
- # TODO: fix this to not need these
30
- def __hash__(self): return hash((super().__hash__(), self.shape))
31
- def __eq__(self, x): return super().__eq__(x) and self.shape == x.shape
32
- def __ne__(self, x): return super().__ne__(x) or self.shape != x.shape
33
29
 
30
+ # @dataclass(frozen=True, init=False, repr=False, eq=False)
34
31
  class PtrDType(DType):
35
- def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
32
+ def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
36
33
  def __repr__(self): return f"ptr.{super().__repr__()}"
34
+ def __hash__(self): return super().__hash__()
35
+ def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
36
+ def __ne__(self, dt): return not (self == dt)
37
37
 
38
38
  class dtypes:
39
39
  @staticmethod
@@ -43,25 +43,31 @@ class dtypes:
43
43
  @staticmethod
44
44
  def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
45
45
  @staticmethod
46
- def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
47
- @staticmethod # NOTE: isinstance(True, int) is True in python
48
- def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
46
+ def from_py(x) -> DType:
47
+ if x.__class__ is float: return dtypes.default_float
48
+ if x.__class__ is int: return dtypes.default_int
49
+ if x.__class__ is bool: return dtypes.bool
50
+ # put this in the last is faster because there are more items than lists/tuples to check
51
+ if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
52
+ raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
53
+ @staticmethod
54
+ def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
49
55
  @staticmethod
50
56
  def fields() -> Dict[str, DType]: return DTYPES_DICT
51
- bool: Final[DType] = DType(0, 1, "bool", np.bool_)
52
- int8: Final[DType] = DType(1, 1, "char", np.int8)
53
- uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
54
- int16: Final[DType] = DType(3, 2, "short", np.int16)
55
- uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
56
- int32: Final[DType] = DType(5, 4, "int", np.int32)
57
- uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
58
- int64: Final[DType] = DType(7, 8, "long", np.int64)
59
- uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
60
- float16: Final[DType] = DType(9, 2, "half", np.float16)
57
+ bool: Final[DType] = DType(0, 1, "bool", '?', 1)
58
+ int8: Final[DType] = DType(1, 1, "char", 'b', 1)
59
+ uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
60
+ int16: Final[DType] = DType(3, 2, "short", 'h', 1)
61
+ uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
62
+ int32: Final[DType] = DType(5, 4, "int", 'i', 1)
63
+ uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
64
+ int64: Final[DType] = DType(7, 8, "long", 'l', 1)
65
+ uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1)
66
+ float16: Final[DType] = DType(9, 2, "half", 'e', 1)
61
67
  # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
62
- bfloat16: Final[DType] = DType(10, 2, "__bf16", None)
63
- float32: Final[DType] = DType(11, 4, "float", np.float32)
64
- float64: Final[DType] = DType(12, 8, "double", np.float64)
68
+ bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
69
+ float32: Final[DType] = DType(11, 4, "float", 'f', 1)
70
+ float64: Final[DType] = DType(12, 8, "double", 'd', 1)
65
71
 
66
72
  # dtype aliases
67
73
  half = float16; float = float32; double = float64 # noqa: E702
@@ -70,13 +76,17 @@ class dtypes:
70
76
 
71
77
  # NOTE: these are image dtypes
72
78
  @staticmethod
73
- def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp, dtypes.float32)
79
+ def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32)
74
80
  @staticmethod
75
- def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32)
81
+ def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32)
76
82
 
77
83
  default_float: ClassVar[DType] = float32
78
84
  default_int: ClassVar[DType] = int32
79
85
 
86
+ if (env_default_float := getenv("DEFAULT_FLOAT", "")):
87
+ dtypes.default_float = getattr(dtypes, env_default_float.lower())
88
+ assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
89
+
80
90
  # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
81
91
  # we don't support weak type and complex type
82
92
  promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
@@ -94,4 +104,10 @@ def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else
94
104
 
95
105
  # HACK: staticmethods are not callable in 3.8 so we have to compare the class
96
106
  DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)}
97
- INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
107
+ INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
108
+
109
+ def sum_acc_dtype(dt:DType):
110
+ # default acc dtype for sum
111
+ if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
112
+ if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
113
+ return least_upper_dtype(dt, dtypes.float)
tinygrad/engine/__init__.py
File without changes
tinygrad/engine/graph.py ADDED
@@ -0,0 +1,100 @@
1
+ import os, atexit, functools, contextlib
2
+ from collections import defaultdict
3
+ from typing import List, Any, DefaultDict, Union
4
+ from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
5
+ from tinygrad.device import Device
6
+ from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
7
+ from tinygrad.codegen.uops import UOps, UOp, UPat
8
+ from tinygrad.shape.symbolic import NumNode
9
+ from tinygrad.lazy import LazyBuffer
10
+
11
+ with contextlib.suppress(ImportError): import networkx as nx
12
+
13
+ # **** debugging and graphing ****
14
+
15
+ if DEBUG >= 2:
16
+ def print_globalcounters():
17
+ if GlobalCounters.time_sum_s == 0: return
18
+ print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
19
+ f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
20
+ atexit.register(print_globalcounters)
21
+
22
+ def save_graph(G, fn, opt=""):
23
+ print("saving", G, f"to {fn}.svg")
24
+ nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
25
+ os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')
26
+
27
+ G:Any = None
28
+ def init_graph():
29
+ global G
30
+ if G is not None: return
31
+ G = nx.DiGraph()
32
+ atexit.register(functools.partial(save_graph, G, GRAPHPATH)) # -Gnslimit=100 can make it finish, but you won't like results
33
+
34
+ counts: DefaultDict[type, int] = defaultdict(int)
35
+ def nm(x):
36
+ if not hasattr(x, 'node_id'):
37
+ setattr(x, 'node_id', counts[type(x)])
38
+ counts[type(x)] += 1
39
+ return x.node_id
40
+
41
+ def realized_lazybuffer(lb:'LazyBuffer', num):
42
+ init_graph()
43
+ G.nodes[nm(lb)]['style'] = '"filled,bold"'
44
+ G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
45
+ G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num}"'
46
+
47
+ top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
48
+ TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
49
+ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
50
+ init_graph()
51
+ if lb.base.realized is None and lb.base.op is LoadOps.CONST: return
52
+ if lb.base != lb:
53
+ offset = lb.st.expr_idxs([NumNode(0)] * len(lb.st.shape))[0]
54
+ label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
55
+ G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
56
+ G.add_edge(nm(lb.base), nm(lb), color='#00000060')
57
+ lb = lb.base
58
+ if lb.realized is None:
59
+ label_append = []
60
+ for idx,x in enumerate(lb.srcs):
61
+ if nm(x) not in G.nodes: log_lazybuffer(x)
62
+ if x.base.realized is None and x.base.op is LoadOps.CONST:
63
+ label_append.append(f"\nCONST{idx} {x.base.arg:g}")
64
+ else:
65
+ G.add_edge(nm(x), nm(lb), color='#a0a0a0')
66
+ label = '"' + \
67
+ (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
68
+ (f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {LoadOps.CONST, UnaryOps.CAST} else "") + \
69
+ (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + '"'
70
+ G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
71
+ if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
72
+ else:
73
+ if nm(lb) not in G.nodes:
74
+ # realized but unseen?
75
+ G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
76
+
77
+ def _tree(dag:Union[LazyOp, UOp, UPat], cycles, cnt):
78
+ cnt[0] += 1
79
+ src = dag.src if isinstance(dag.src, (list, tuple)) else [] if dag.src is None else [dag.src]
80
+ if len(src) == 0: return [f"━━ {dag.op} {dag.arg}"]
81
+ if (lid := id(dag)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
82
+ return [f"━⬆︎ goto {cycles[id(dag)][0]}: {dag.op}"]
83
+ cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
84
+ lines = [f"━┳ {dag.op} {dag.arg}"]
85
+ childs = [_tree(c, cycles, cnt) for c in src]
86
+ for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
87
+ return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
88
+
89
+ def print_tree(dag:Union[LazyOp, UOp, UPat]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(dag, {}, [-1]))]))
90
+
91
+ def graph_uops(uops:List[UOp]):
92
+ colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
93
+ UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0",
94
+ UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
95
+ G = nx.DiGraph()
96
+ for u in uops:
97
+ if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
98
+ G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg).replace(':', '')) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
99
+ for v in u.src: G.add_edge(uops.index(v), uops.index(u))
100
+ save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
tinygrad/engine/jit.py ADDED
@@ -0,0 +1,198 @@
1
+ from __future__ import annotations
2
+ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
3
+ import functools, itertools, collections
4
+ from tinygrad.tensor import Tensor
5
+ from tinygrad.lazy import LazyBuffer
6
+ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, ContextVar, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
7
+ from tinygrad.device import Buffer, Compiled, Device
8
+ from tinygrad.dtype import DType
9
+ from tinygrad.shape.shapetracker import ShapeTracker
10
+ from tinygrad.shape.symbolic import Variable, sint
11
+ from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
12
+ from tinygrad.engine.schedule import _internal_memory_planner
13
+ from tinygrad.nn.state import get_parameters
14
+ from weakref import WeakKeyDictionary
15
+
16
+ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
17
+ # Split JIT cache into batches for faster graph execution.
18
+ # This allows the accelerator to run some batches while subsequent graphs are still being updated.
19
+ max_batch_size = getenv("JIT_BATCH_SIZE", 32)
20
+ graphed_jit_cache: List[ExecItem] = []
21
+ current_batch: List[ExecItem] = []
22
+ current_device: Optional[Compiled] = None
23
+
24
+ def flush_batch():
25
+ nonlocal current_batch, current_device, max_batch_size
26
+ try:
27
+ if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
28
+ graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
29
+ # clear jit inputs to allow their memory to be freed/reused
30
+ for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
31
+ graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
32
+ max_batch_size *= 2
33
+ if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
34
+ except GraphException as e:
35
+ graphed_jit_cache.extend(current_batch)
36
+ if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
37
+ current_batch = []
38
+ current_device = None
39
+
40
+ for ji in jit_cache:
41
+ if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
42
+ ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
43
+ if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
44
+ elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
45
+ ji_graph_dev = Device[ji.bufs[0].device]
46
+
47
+ graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
48
+ can_be_graphed = ji_graph_dev and ji_graph_dev.graph
49
+ can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
50
+ type(ji_graph_dev) == type(current_device))
51
+ can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph
52
+ if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
53
+
54
+ if can_be_graphed: current_batch.append(ji)
55
+ else: graphed_jit_cache.append(ji)
56
+
57
+ current_device = ji_graph_dev
58
+
59
+ if len(current_batch) > 0: flush_batch()
60
+ return graphed_jit_cache
61
+
62
+ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
63
+ input_replace: Dict[Tuple[int, int], int] = {}
64
+ for j,ji in enumerate(jit_cache):
65
+ for i,a in enumerate(ji.bufs):
66
+ if a in input_rawbuffers:
67
+ input_replace[(j,i)] = input_rawbuffers.index(a)
68
+ return input_replace
69
+
70
+ class GraphRunner(Runner): # pylint: disable=abstract-method
71
+ def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
72
+ self.jit_cache = jit_cache
73
+ self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
74
+ self.jc_idx_with_updatable_launch_dims = []
75
+ self.jc_idx_with_updatable_var_vals = []
76
+ op_estimate: sint = 0
77
+ mem_estimate: sint = 0
78
+ for j,ji in enumerate(jit_cache):
79
+ op_estimate += ji.prg.op_estimate
80
+ mem_estimate += ji.prg.mem_estimate
81
+ if isinstance(ji.prg, CompiledRunner):
82
+ if ji.prg.p.vars: self.jc_idx_with_updatable_var_vals.append(j)
83
+ if (ji.prg.p.global_size and not all_int(ji.prg.p.global_size)) or (ji.prg.p.local_size and not all_int(ji.prg.p.local_size)):
84
+ self.jc_idx_with_updatable_launch_dims.append(j)
85
+ self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
86
+ super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)
87
+
88
+ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
89
+ def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
90
+ self.w_dependency_map: Dict[Any, Any] = {}
91
+ self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
92
+ super().__init__(jit_cache, input_rawbuffers, var_vals)
93
+
94
+ def _access_resources(self, read, write, new_dependency:Any):
95
+ # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
96
+ # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
97
+ wait_nodes = []
98
+
99
+ for rawbuf in read + write:
100
+ if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
101
+ for rawbuf in write:
102
+ if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
103
+
104
+ for rawbuf in read: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
105
+ for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
106
+ return list({id(x):x for x in wait_nodes}.values())
107
+
108
+ ReturnType = TypeVar('ReturnType')
109
+ IN_JIT = ContextVar('IN_JIT', 0)
110
+ class TinyJit(Generic[ReturnType]):
111
+ def __init__(self, fxn:Callable[..., ReturnType]):
112
+ self.fxn = fxn
113
+ self.reset()
114
+
115
+ def add_buffer(self, b:Buffer) -> Buffer:
116
+ if found:=self.buffer_replace.get(b, None): return found
117
+ if b.is_allocated() or b.lb_refcount > 0: return b
118
+ if b._base is not None:
119
+ self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.buffer_replace.get(b._base, b._base), offset=b.offset)
120
+ else:
121
+ self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
122
+ return ret
123
+
124
+ def add(self, ei:ExecItem):
125
+ self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
126
+
127
+ def reset(self):
128
+ self.jit_cache: List[ExecItem] = []
129
+ self.input_replace: Dict[Tuple[int, int], int] = {}
130
+ self.extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
131
+ self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
132
+ self.cnt: int = 0
133
+
134
+ def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
135
+
136
+ def __call__(self, *args, **kwargs) -> ReturnType:
137
+ input_tensors: List[Tuple[Union[int, str], Tensor]] = \
138
+ [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
139
+ if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
140
+ names: List[Union[int, str]] = [name for name,_ in input_tensors]
141
+ lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
142
+ st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
143
+ input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
144
+ assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
145
+ var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
146
+ [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
147
+ st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
148
+ if not JIT or self.cnt == 0:
149
+ if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
150
+ # jit ignore
151
+ with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, IN_JIT=1):
152
+ self.ret = self.fxn(*args, **kwargs)
153
+ if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
154
+ elif self.cnt == 1:
155
+ # jit capture
156
+ self.expected_names: List[Union[int, str]] = names
157
+ self.expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = st_vars_dtype_device
158
+ with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
159
+ capturing.append(self)
160
+ self.ret = self.fxn(*args, **kwargs)
161
+ if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
162
+ capturing.clear()
163
+ del self.buffer_replace
164
+ assert len(self.jit_cache), "didn't JIT anything!"
165
+ if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_buffers)} inputs")
166
+
167
+ # track inputs that are views of buffers
168
+ for item in self.jit_cache:
169
+ for b in item.bufs:
170
+ if b is not None and b._base is not None and b._base in input_buffers:
171
+ input_buffers.append(b)
172
+ self.extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
173
+
174
+ # memory planning (optional)
175
+ assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in self.jit_cache], debug_prefix="JIT ")
176
+ self.jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in self.jit_cache]
177
+
178
+ # Condense the items into a graph executor.
179
+ if JIT < 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals)
180
+
181
+ self.input_replace = get_input_replace(self.jit_cache, input_buffers)
182
+ if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
183
+ elif self.cnt >= 2:
184
+ # jit exec
185
+ assert self.expected_names == names, f"args mismatch in JIT: {self.expected_names=} != {names}"
186
+ assert self.expected_st_vars_dtype_device == st_vars_dtype_device, \
187
+ f"args mismatch in JIT: {self.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
188
+ for idx, offset, device, size, dtype in self.extra_view_inputs:
189
+ input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
190
+ for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_buffers[input_idx]
191
+ if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
192
+ for ei in self.jit_cache: ei.run(var_vals, jit=True)
193
+
194
+ # clear jit inputs
195
+ for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
196
+
197
+ self.cnt += 1
198
+ return self.ret
@@ -0,0 +1,192 @@
1
+ from typing import List, Dict, Optional, cast, Generator, Tuple
2
+ import time
3
+ from dataclasses import dataclass, replace
4
+ from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING
5
+ from tinygrad.ops import BufferOps, LoadOps, LazyOp
6
+ from tinygrad.device import Device, Buffer
7
+ from tinygrad.shape.symbolic import Variable, sym_infer, sint
8
+ from tinygrad.renderer import Renderer, Program
9
+ from tinygrad.codegen.linearizer import Linearizer
10
+ from tinygrad.engine.schedule import ScheduleItem
11
+
12
+ # **************** Program Creation ****************
13
+
14
# optional kernel log: when LOGKERNS is set it names a file opened in append mode; LOGKERNS_LEVEL defaults to 1
logkerns = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None
logkerns_level = getenv("LOGKERNS_LEVEL", 1)
15
def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
  """Build a Linearizer for `ast` and pick the optimizations to apply to it.

  Unless NOOPT is set, tensor cores are tried first (controlled by the TC env var) and
  hand coded optimizations are the fallback. With BEAM >= 1 a beam search is run as well,
  and when BEAM_COMPARE is enabled (default 1) the candidates are timed and the fastest is kept.
  The chosen (ast, applied_opts) pair is appended to the LOGKERNS file when one is open.
  """
  if DEBUG >= 3:
    from tinygrad.engine.graph import print_tree
    for op in ast: print_tree(op)
  k = Linearizer(*ast, opts=renderer)
  k.required_optimizations()
  if not NOOPT:
    used_tensor_cores = k.apply_tensor_cores(getenv("TC", 1))
    if not used_tensor_cores: k.hand_coded_optimizations()
    if BEAM >= 1:
      from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
      # keep the non-beam result around so it can be compared against the beam result below
      k_opt = k
      kb = Linearizer(*ast, opts=renderer)
      kb.required_optimizations()
      rawbufs = bufs_from_lin(kb, allocate=False)
      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
      if getenv("BEAM_COMPARE", 1):
        # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
        candidates: List[Tuple[str, Linearizer]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
        if used_tensor_cores:
          # tensor cores were applied above, so also time a purely hand coded variant
          hc_lin = Linearizer(*ast, opts=renderer)
          candidates.append(("hc", hc_lin))
          hc_lin.hand_coded_optimizations()
        timings = [(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in candidates]
        timed = sorted(timings, key=lambda x: x[2])
        if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
        k = timed[0][1]
        # at log level > 1 the losing candidates are logged too
        if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
  # TODO: check the correctness inline once compare_linearizer is in core
  if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
  if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
  return k
43
+
44
+ # **************** Runners ****************
45
+
46
class Runner:
  """Base class for things that can be executed on a list of buffers; subclasses override __call__."""
  def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0):
    self.first_run = True
    self.display_name = display_name
    self.dname = dname
    self.op_estimate = op_estimate
    self.mem_estimate = mem_estimate
  @property
  def device(self):
    # looked up by name on every access
    return Device[self.dname]
  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
    """Call the runner, substituting an empty dict when no var_vals are given."""
    if var_vals is None: return self(rawbufs, {})
    return self(rawbufs, var_vals)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
    raise NotImplementedError("override this")
55
+
56
class CompiledRunner(Runner):
  """Runner for a rendered Program: compiles the source (or reuses a precompiled binary) and launches it."""
  def __init__(self, p:Program, precompiled:Optional[bytes]=None):
    if DEBUG >= 4: print(p.src)
    self.p:Program = p
    dev = Device[p.dname]
    # only compile when the caller did not hand over a binary
    self.lib:bytes = dev.compiler.compile_cached(p.src) if precompiled is None else precompiled
    self.clprg = dev.runtime(p.function_name, self.lib)
    super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate)

  def __reduce__(self):
    # rebuilt from the program and the compiled binary
    return self.__class__, (self.p, self.lib)

  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
    """Launch the program on `rawbufs`, passing the values of the program's vars taken from `var_vals`."""
    global_size, local_size = self.p.launch_dims(var_vals)
    if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
      # TODO: this is copied from get_program
      from tinygrad.engine.search import optimize_local_size
      local_size = optimize_local_size(self.clprg, global_size, rawbufs)
      global_size = [gs//ls if gs%ls == 0 else gs/ls for gs,ls in zip(global_size, local_size)]
      # store the chosen sizes back on the program
      self.p = replace(self.p, global_size=global_size, local_size=local_size)
    launch_args = {}
    if global_size:
      launch_args['global_size'] = global_size
      assert len(global_size) == 3, "global size must have len 3"
    if local_size:
      launch_args['local_size'] = local_size
      assert len(local_size) == 3, "local size must have len 3"
    raw = [x._buf for x in rawbufs]
    vals = tuple(var_vals[k] for k in self.p.vars)
    return self.clprg(*raw, **launch_args, vals=vals, wait=wait)
82
+
83
class CustomOp(Runner):
  """Runner that wraps a plain callable and invokes it with the buffers as positional arguments."""
  def __init__(self, fxn):
    self.fxn = fxn
    super().__init__(fxn.__name__, "CUSTOM", 0, 0)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
    # var_vals and wait are ignored; nothing is returned
    self.fxn(*rawbufs)
88
+
89
class EmptyOp(Runner):
  """Runner that does nothing when called; it only carries a display name built from the buffer."""
  def __init__(self, buf:Buffer):
    display_name = colored(f"empty {buf.size:10d} {buf.dtype}", "yellow")
    super().__init__(display_name, buf.device)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
    return None
92
+
93
class ViewOp(Runner):
  """Runner whose call only asserts that the first buffer is a view sharing its base with the second."""
  def __init__(self, buf:Buffer):
    display_name = colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow")
    super().__init__(display_name, buf.device)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
    # the first buffer must have a base, and it must be the base of the second buffer
    assert rawbufs[0]._base is not None, f"must be base {rawbufs}"
    assert rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
97
+
98
class BufferCopy(Runner):
  """Runner that copies the second buffer into the first one."""
  def __init__(self, total_sz, dest_device, src_device):
    # the [6:] slice drops the first six characters of the class name before lowercasing
    name = (f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}" if total_sz >= 1e6
            else f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}")
    super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
  def copy(self, dest, src):
    """Copy `src` into `dest`, using one of the DISK-specific paths when the source device name starts with DISK."""
    from_disk = src.device.startswith("DISK")
    fast_copyout = from_disk and hasattr(src.allocator.device, 'io_uring') and hasattr(src.allocator.device, 'fd')
    if fast_copyout and hasattr(dest.allocator, 'copy_from_disk') and src.nbytes >= 4096:
      dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
    elif from_disk and hasattr(dest.allocator, 'as_buffer'):
      # fast(ish) path, uses readinto in diskbuffers
      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
    else:
      dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
    """Run the copy; returns the elapsed seconds when `wait` is set, otherwise None."""
    dest, src = rawbufs[0:2]
    assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
    st = time.perf_counter()
    self.copy(dest, src)
    if not wait: return None
    # timing is only reported after the destination device has been synchronized
    Device[dest.device].synchronize()
    return time.perf_counter() - st
120
+
121
class BufferXfer(BufferCopy):
  """BufferCopy variant that copies through the destination allocator's `transfer` method."""
  def copy(self, dest, src):
    dest_dev, src_alloc = dest.allocator.device, src.allocator
    # record the cross-device dependency when both sides expose the tracking attributes
    if hasattr(dest_dev, "track_cross_buffer") and hasattr(src_alloc, "track_cross_device"):
      dest_dev.track_cross_buffer.append(src)
      src_alloc.track_cross_device.add(dest.allocator.device)
    dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
127
+
128
+ # **************** method cache ****************
129
+
130
# keyed by (device name, ast, BEAM value, is_base_device_entry)
method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], int, bool], CompiledRunner] = {}
def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
  """Return the cached runner for (dname, ast), creating and caching one when needed.

  A second cache entry keyed on the device name before the first ":" lets another
  device with the same prefix reuse the program and compiled binary.
  """
  beam = BEAM.value
  ckey = (dname, ast, beam, False)
  cached = method_cache.get(ckey)
  if cached: return cached
  bkey = (dname.split(":")[0], ast, beam, True)
  shared = method_cache.get(bkey)
  if shared:
    # same program and binary, retargeted at this device name
    runner = CompiledRunner(replace(shared.p, dname=dname), shared.lib)
    method_cache[ckey] = runner
    return runner
  prg: Program = get_linearizer(Device[dname].renderer, ast).to_program()
  if hasattr(prg.uops, "fuzz_paths"):
    # fuzzer runners are returned without being cached
    from test.external.fuzz_uops import UOpsFuzzerRunner
    return UOpsFuzzerRunner(replace(prg, dname=dname))
  runner = CompiledRunner(replace(prg, dname=dname))
  method_cache[ckey] = method_cache[bkey] = runner
  return runner
144
+
145
+ # **************** lowering functions ****************
146
+
147
@dataclass(frozen=True)
class ExecItem:
  """A runner paired with the buffers it is called on."""
  prg: Runner
  bufs: List[Optional[Buffer]]
  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
    """Call the runner on the buffers and optionally update GlobalCounters.

    var_vals: variable bindings; None is treated as an empty dict.
    wait: forwarded to the runner; forced on when DEBUG >= 2.
    jit: when set, buffers are passed as they are instead of going through ensure_allocated().
    do_update_stats: update GlobalCounters (and print a stats line when DEBUG >= 2).
    Returns whatever the runner returned.
    """
    # normalize once: previously the runner got {} but sym_infer below still received the raw None
    var_vals = {} if var_vals is None else var_vals
    bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
    et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
    if do_update_stats:
      GlobalCounters.kernel_count += 1
      GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))
      GlobalCounters.global_mem += (mem_estimate:=sym_infer(self.prg.mem_estimate, var_vals))
      if et is not None: GlobalCounters.time_sum_s += et
      if DEBUG >= 2:
        # times above 0.01 s are shown in (colored) ms, shorter ones in us
        ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
              (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))  # noqa: E501
      self.prg.first_run = False
    return et
165
+
166
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
  """Turn a ScheduleItem into an ExecItem by picking the runner that matches the op of its first ast.

  Raises RuntimeError when the op is none of STORE, COPY, CUSTOM, EMPTY or VIEW.
  """
  # all buffers must share a device unless this is a copy (or USE_COPY_KERNEL is set)
  assert len({x.device for x in si.bufs}) == 1 or si.ast[0].op is LoadOps.COPY or getenv("USE_COPY_KERNEL")
  if si.ast[0].op is BufferOps.STORE:
    runner = get_runner(si.outputs[0].device, si.ast)
    # only the buffers indexed by the program's globals are passed on
    return ExecItem(runner, [si.bufs[g[0]] for g in runner.p.globals])
  out, ast = si.outputs[0], si.ast[0]
  if ast.op is LoadOps.COPY:
    # use a transfer when the allocator has one and both devices share the name before the first ":"
    can_xfer = hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]
    kernel_type = BufferXfer if can_xfer else BufferCopy
    return ExecItem(kernel_type(ast.arg, out.device, si.inputs[0].device), list(si.bufs))
  if ast.op is LoadOps.CUSTOM:
    return ExecItem(CustomOp(ast.arg), list(si.bufs))
  if ast.op is LoadOps.EMPTY:
    return ExecItem(EmptyOp(out), list(si.bufs))
  if ast.op is LoadOps.VIEW:
    return ExecItem(ViewOp(out), list(si.bufs))
  raise RuntimeError(f"don't know how to lower {ast}")
181
+
182
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
  """Lazily lower the schedule front to back; each item is popped off the caller's list as it is lowered."""
  while schedule:
    yield lower_schedule_item(schedule.pop(0))
184
+
185
+ # **************** main run function ****************
186
+
187
capturing: List = [] # put classes with an add method in here

def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None, do_update_stats=True):
  """Lower and run every item of `schedule` in order.

  While `capturing` is non-empty and CAPTURING is truthy, each lowered item is also
  handed to the first capturing object's add method before it runs.
  """
  for exec_item in lower_schedule(schedule):
    if capturing and CAPTURING:
      capturing[0].add(exec_item)
    exec_item.run(var_vals, do_update_stats=do_update_stats)