tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/engine/realize.py
@@ -1,11 +1,10 @@
-from typing import List, Dict, Optional, cast, Generator, Tuple
+from typing import Optional, cast, Generator
 import time, pprint
 from dataclasses import dataclass, replace
-from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA
-from tinygrad.ops import Ops, UOp, Variable, sym_infer, sint
-from tinygrad.dtype import dtypes
+from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
+from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
-from tinygrad.renderer import Renderer, Program
+from tinygrad.renderer import Renderer, ProgramSpec, Estimates
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import ScheduleItem

@@ -13,50 +12,15 @@ from tinygrad.engine.schedule import ScheduleItem

 logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
 def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
-  if DEBUG >= 5:
-    print(ast)
+  if DEBUG >= 5: print(ast)
   k = Kernel(ast, opts=renderer).required_optimizations()
   if not NOOPT:
-    if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
+    if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
     if BEAM >= 1:
-      from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
-      kb, k_opt = Kernel(ast, opts=renderer).required_optimizations(), k
+      from tinygrad.engine.search import beam_search, bufs_from_lin
+      kb = Kernel(ast, opts=renderer).required_optimizations()
       rawbufs = bufs_from_lin(kb, allocate=False)
-      if BEAM.value >= 100:
-        from extra.mcts_search import mcts_search
-        k = mcts_search(kb, rawbufs, BEAM.value)
-      else:
-        k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-      if beam_compare:=getenv("BEAM_COMPARE", 1):
-        # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
-        lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
-        if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
-        timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
-        if DEBUG >= 3: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
-        k = timed[0][1]
-        if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
-        if beam_compare == 2:
-          from tinygrad import Tensor
-          all_outs: List[List[Tensor]] = []
-          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
-            rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
-                        (Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
-                         Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
-                         for buf in rawbufs]
-          for _, tk in lins[::-1]:
-            for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
-            time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
-            all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
-          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
-            for bufs in zip(*all_outs):
-              for b in bufs[1:]:
-                if dtypes.is_float(bufs[0].dtype):
-                  # we check both atol and rtol here
-                  diff_count = (((b-bufs[0]).abs() > 1e-3) * (((b-bufs[0])/bufs[0]).abs() > 1e-3)).sum().item()
-                else:
-                  diff_count = (b != bufs[0]).sum().item()
-                if diff_count != 0:
-                  raise RuntimeError(f"mismatch of {diff_count}/{b.numel()} items with type {b.dtype}, max {(b-bufs[0]).abs().max().item()}")
+      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
   if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
   if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
   return k
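
Note on the hunk above: the MCTS fallback (BEAM >= 100) and the whole BEAM_COMPARE verification path are gone, so BEAM >= 1 now goes straight to beam_search. A minimal sketch of exercising this path, assuming tinygrad 0.10.1 is installed; the usual form is BEAM=2 python script.py, and setting the variable before the first tinygrad import works the same way since BEAM is read from the environment:

import os
os.environ["BEAM"] = "2"   # must be set before tinygrad is first imported

from tinygrad import Tensor
# every kernel scheduled for this matmul now reaches beam_search() inside get_kernel
(Tensor.rand(256, 256) @ Tensor.rand(256, 256)).realize()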
@@ -64,33 +28,32 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
 # **************** Runners ****************

 class Runner:
-  def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
-    self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate, self.lds_estimate = \
-      True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
+  def __init__(self, display_name:str, device:str, estimates=Estimates()):
+    self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
   @property
-  def device(self): return Device[self.dname]
-  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
+  def dev(self): return Device[self.device]
+  def exec(self, rawbufs:list[Buffer], var_vals:Optional[dict[Variable, int]]=None) -> Optional[float]:
     return self(rawbufs, {} if var_vals is None else var_vals)
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
     raise NotImplementedError("override this")

 class CompiledRunner(Runner):
-  def __init__(self, p:Program, precompiled:Optional[bytes]=None):
+  def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None):
     if DEBUG >= 4: print(p.src)
-    self.p:Program = p
-    self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
-    if DEBUG >= 6: Device[p.dname].compiler.disassemble(self.lib)
-    self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
-    super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate, p.lds_estimate)
+    self.p:ProgramSpec = p
+    self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
+    if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
+    self._prg = Device[p.device].runtime(p.function_name, self.lib)
+    super().__init__(p.name, p.device, p.estimates)

   def __reduce__(self): return self.__class__, (self.p, self.lib)

-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
     global_size, local_size = self.p.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
       # TODO: this is copied from get_program
       from tinygrad.engine.search import optimize_local_size
-      local_size = optimize_local_size(self.clprg, global_size, rawbufs)
+      local_size = optimize_local_size(self._prg, global_size, rawbufs)
       global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
       self.p = replace(self.p, global_size=global_size, local_size=local_size)
     lra = {}
@@ -100,33 +63,29 @@ class CompiledRunner(Runner):
     if local_size:
       lra['local_size'] = tuple(local_size)
       assert len(local_size) == 3, "local size must have len 3"
-    return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
-
-class EmptyOp(Runner):
-  def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
+    return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)

 class ViewOp(Runner):
   def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
     assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"

 class BufferCopy(Runner):
   def __init__(self, total_sz, dest_device, src_device):
     if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
     else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
-    super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
+    super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
   def copy(self, dest, src):
-    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.device, 'io_uring') and \
-      getattr(src.allocator.device, 'fd', None) is not None
+    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
+      getattr(src.allocator.dev, 'fd', None) is not None
     if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
       dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
-    elif src.device.startswith("DISK") and hasattr(dest.allocator, 'as_buffer'):
+    elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
       # fast(ish) path, uses readinto in diskbuffers
-      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
+      src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
     else:
       dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
     dest, src = rawbufs[0:2]
     assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
     st = time.perf_counter()
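
The two hunks above fold the loose op_estimate/mem_estimate/lds_estimate fields into a single Estimates value imported from tinygrad.renderer. A rough reconstruction of its shape, inferred only from the usages visible in this diff (Estimates(lds=..., mem=...) here and the estimates.ops/.mem/.lds reads in ExecItem.run below); the real definition lives in tinygrad/renderer/__init__.py and may carry extra fields or helpers:

from dataclasses import dataclass

@dataclass
class Estimates:
  ops: int = 0  # compute estimate, drives the GFLOPS number in the DEBUG=2 line
  lds: int = 0  # loads/stores issued, drives the second GB/s number
  mem: int = 0  # memory traffic estimate, feeds GlobalCounters.global_mem

# a Runner now takes one bundle instead of three positional estimates:
#   Runner(display_name, device, Estimates(lds=total_sz, mem=total_sz))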
@@ -136,23 +95,20 @@ class BufferCopy(Runner):
     return time.perf_counter() - st

 class BufferXfer(BufferCopy):
-  def copy(self, dest, src): dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
+  def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev)

 # **************** method cache ****************

-method_cache: Dict[Tuple[str, bytes, int, int, bool], CompiledRunner] = {}
-def get_runner(dname:str, ast:UOp) -> CompiledRunner:
-  ckey = (dname, ast.key, BEAM.value, NOOPT.value, False)
+method_cache: dict[tuple[str, bytes, int, int, bool], CompiledRunner] = {}
+def get_runner(device:str, ast:UOp) -> CompiledRunner:
+  ckey = (device, ast.key, BEAM.value, NOOPT.value, False)
   if cret:=method_cache.get(ckey): return cret
-  bkey = (dname.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
+  bkey = (device.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
   if bret:=method_cache.get(bkey):
-    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
+    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
   else:
-    prg: Program = get_kernel(Device[dname].renderer, ast).to_program()
-    if getenv("FUZZ_UOPS"):
-      from test.external.fuzz_uops import UOpsFuzzerRunner
-      return UOpsFuzzerRunner(replace(prg, dname=dname))
-    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
+    prg: ProgramSpec = get_kernel(Device[device].renderer, ast).to_program()
+    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
   return ret

 # **************** lowering functions ****************
@@ -160,43 +116,38 @@ def get_runner(dname:str, ast:UOp) -> CompiledRunner:
 @dataclass(frozen=True)
 class ExecItem:
   prg: Runner
-  bufs: List[Optional[Buffer]]
-  metadata: Optional[Tuple[Metadata, ...]] = None
-  def run(self, _var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
+  bufs: list[Optional[Buffer]]
+  metadata: Optional[tuple[Metadata, ...]] = None
+  def run(self, _var_vals:Optional[dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
     var_vals = {} if _var_vals is None else _var_vals
     bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
     et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
     if do_update_stats:
       GlobalCounters.kernel_count += 1
-      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.op_estimate, var_vals))
-      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.mem_estimate, var_vals))
+      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals))
+      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals))
       if et is not None: GlobalCounters.time_sum_s += et
       if DEBUG >= 2:
-        lds_est = sym_infer(self.prg.lds_estimate, var_vals)
+        lds_est = sym_infer(self.prg.estimates.lds, var_vals)
         mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
         ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
-        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
+        print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
               (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
               f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
       self.prg.first_run = False
     return et

-def lower_schedule_item(si:ScheduleItem) -> ExecItem:
-  assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is Ops.COPY
-  if si.ast.op is Ops.SINK:
-    runner = get_runner(si.outputs[0].device, si.ast)
-    return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
-  out, arg = si.outputs[0], si.ast.arg
-  if si.ast.op is Ops.COPY:
-    kernel_type = BufferCopy
-    if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
-      kernel_type = BufferXfer
-    return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
-  if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
-  if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
-  raise RuntimeError(f"don't know how to lower {si.ast}")
-
-def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
+# NOTE: ctx is the buffers
+si_lowerer = PatternMatcher([
+  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
+  (UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
+  (UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
+      if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
+      else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
+  ])
+def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
+
+def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
   while len(schedule):
     si = schedule.pop(0)
     try: yield lower_schedule_item(si)
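
The if/elif chain in lower_schedule_item is replaced by a table of (UPat, lambda) rules: si_lowerer.rewrite(si.ast, si.bufs) finds the first pattern matching the AST's root op and calls its lambda with the buffers as ctx (the old Ops.EMPTY branch is dropped entirely). A tiny self-contained illustration of this dispatch style, assuming tinygrad 0.10.1 is importable; the pattern and return value below are made up for the example:

from tinygrad.ops import Ops, UOp, UPat, PatternMatcher
from tinygrad.dtype import dtypes

pm = PatternMatcher([
  # the first matching UPat wins; nodes captured by name are passed to the lambda
  (UPat(Ops.CONST, name="c"), lambda c: f"lowered CONST {c.arg}"),
])
print(pm.rewrite(UOp.const(dtypes.int, 42)))  # -> lowered CONST 42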
@@ -209,9 +160,9 @@ def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, Non

 # **************** main run function ****************

-capturing: List = [] # put classes with an add method in here
+capturing: list = [] # put classes with an add method in here

-def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None, do_update_stats=True):
+def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, int]]=None, do_update_stats=True):
   for ei in lower_schedule(schedule):
     if len(capturing) and CAPTURING: capturing[0].add(ei)
     ei.run(var_vals, do_update_stats=do_update_stats)
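
For orientation, a hedged end-to-end sketch of how the pieces in this file are driven, assuming tinygrad 0.10.1 is installed; Tensor.schedule() produces the ScheduleItems that lower_schedule_item and run_schedule consume, and this snippet is not part of the diff:

from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule_item, run_schedule

t = Tensor.arange(16).reshape(4, 4) + 1
sched = t.schedule()                      # list[ScheduleItem], one per kernel/copy
for si in sched:                          # lowering alone compiles but does not execute
  print(lower_schedule_item(si).prg.display_name)
run_schedule(sched)                       # lowers again (method_cache hit) and runs
print(t.numpy())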