tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/realize.py
@@ -1,11 +1,11 @@
-from typing import List, Dict, Optional, cast, Generator, Tuple
+from typing import Optional, cast, Generator
 import time, pprint
 from dataclasses import dataclass, replace
-from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA
-from tinygrad.ops import Ops, UOp, Variable, sym_infer, sint
-from tinygrad.dtype import dtypes
+from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
+from tinygrad.helpers import DEVECTORIZE, time_to_str
+from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
-from tinygrad.renderer import Renderer, Program
+from tinygrad.renderer import Renderer, ProgramSpec, Estimates
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import ScheduleItem
 
@@ -13,50 +13,15 @@ from tinygrad.engine.schedule import ScheduleItem
 
 logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
 def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
-  if DEBUG >= 5:
-    print(ast)
+  if DEBUG >= 5: print(ast)
   k = Kernel(ast, opts=renderer).required_optimizations()
   if not NOOPT:
-    if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
+    if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
     if BEAM >= 1:
-      from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
-      kb, k_opt = Kernel(ast, opts=renderer).required_optimizations(), k
+      from tinygrad.engine.search import beam_search, bufs_from_lin
+      kb = Kernel(ast, opts=renderer).required_optimizations()
       rawbufs = bufs_from_lin(kb, allocate=False)
-      if BEAM.value >= 100:
-        from extra.mcts_search import mcts_search
-        k = mcts_search(kb, rawbufs, BEAM.value)
-      else:
-        k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-      if beam_compare:=getenv("BEAM_COMPARE", 1):
-        # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
-        lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
-        if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
-        timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
-        if DEBUG >= 3: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
-        k = timed[0][1]
-        if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
-        if beam_compare == 2:
-          from tinygrad import Tensor
-          all_outs: List[List[Tensor]] = []
-          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
-            rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
-                         (Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
-                          Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
-                         for buf in rawbufs]
-          for _, tk in lins[::-1]:
-            for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
-            time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
-            all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
-          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
-            for bufs in zip(*all_outs):
-              for b in bufs[1:]:
-                if dtypes.is_float(bufs[0].dtype):
-                  # we check both atol and rtol here
-                  diff_count = (((b-bufs[0]).abs() > 1e-3) * (((b-bufs[0])/bufs[0]).abs() > 1e-3)).sum().item()
-                else:
-                  diff_count = (b != bufs[0]).sum().item()
-                if diff_count != 0:
-                  raise RuntimeError(f"mismatch of {diff_count}/{b.numel()} items with type {b.dtype}, max {(b-bufs[0]).abs().max().item()}")
+      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
   if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
   if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
   return k
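
The kernel-selection path above is now driven entirely by context variables (TC, NOOPT, BEAM, LOGKERNS), with the MCTS and BEAM_COMPARE branches removed. A minimal sketch of exercising it directly on a scheduled kernel AST; the tensor shapes and the use of Device.DEFAULT are illustrative, not part of the diff:

from tinygrad import Tensor
from tinygrad.device import Device
from tinygrad.ops import Ops
from tinygrad.engine.realize import get_kernel

out = (Tensor.rand(16, 16) @ Tensor.rand(16, 16)).sum()
sink = next(si.ast for si in out.schedule() if si.ast.op is Ops.SINK)  # first kernel AST in the schedule
k = get_kernel(Device[Device.DEFAULT].renderer, sink)  # honors TC / NOOPT / BEAM from the environment
print(k.applied_opts)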
@@ -64,33 +29,32 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
 # **************** Runners ****************
 
 class Runner:
-  def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
-    self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate, self.lds_estimate = \
-      True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
+  def __init__(self, display_name:str, device:str, estimates=Estimates()):
+    self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
   @property
-  def device(self): return Device[self.dname]
-  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
+  def dev(self): return Device[self.device]
+  def exec(self, rawbufs:list[Buffer], var_vals:Optional[dict[Variable, int]]=None) -> Optional[float]:
     return self(rawbufs, {} if var_vals is None else var_vals)
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
    raise NotImplementedError("override this")
 
 class CompiledRunner(Runner):
-  def __init__(self, p:Program, precompiled:Optional[bytes]=None):
+  def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None):
     if DEBUG >= 4: print(p.src)
-    self.p:Program = p
-    self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
-    if DEBUG >= 6: Device[p.dname].compiler.disassemble(self.lib)
-    self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
-    super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate, p.lds_estimate)
+    self.p:ProgramSpec = p
+    self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
+    if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
+    self._prg = Device[p.device].runtime(p.function_name, self.lib)
+    super().__init__(p.name, p.device, p.estimates)
 
   def __reduce__(self): return self.__class__, (self.p, self.lib)
 
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
     global_size, local_size = self.p.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
       # TODO: this is copied from get_program
       from tinygrad.engine.search import optimize_local_size
-      local_size = optimize_local_size(self.clprg, global_size, rawbufs)
+      local_size = optimize_local_size(self._prg, global_size, rawbufs)
       global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
       self.p = replace(self.p, global_size=global_size, local_size=local_size)
     lra = {}
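
The Runner base class above now takes a single Estimates object in place of the old op_estimate/mem_estimate/lds_estimate arguments, and dname becomes device (with the old device property renamed to dev). A minimal sketch, assuming only the Estimates fields visible in this diff (ops, mem, lds), of a custom Runner written against the new signature; NoopRunner is a hypothetical name used for illustration:

from typing import Optional
from tinygrad.device import Buffer
from tinygrad.ops import Variable
from tinygrad.renderer import Estimates
from tinygrad.engine.realize import Runner

class NoopRunner(Runner):
  def __init__(self, device:str):
    # cost metadata now travels as one Estimates object instead of three loose arguments
    super().__init__("noop", device, Estimates(ops=0, mem=0, lds=0))
  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
    return 0.0  # does nothing; exists only to show the new constructor shape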
@@ -100,33 +64,29 @@ class CompiledRunner(Runner):
     if local_size:
       lra['local_size'] = tuple(local_size)
       assert len(local_size) == 3, "local size must have len 3"
-    return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
-
-class EmptyOp(Runner):
-  def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
+    return self._prg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
 
 class ViewOp(Runner):
   def __init__(self, buf:Buffer): super().__init__(colored(f"view {buf.nbytes:8d} @ {buf.offset:<10d}", "yellow"), buf.device)
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
     assert rawbufs[0]._base is not None and rawbufs[0]._base == rawbufs[1].base, f"must be base {rawbufs}"
 
 class BufferCopy(Runner):
   def __init__(self, total_sz, dest_device, src_device):
     if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
     else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
-    super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
+    super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
   def copy(self, dest, src):
-    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.device, 'io_uring') and \
-                                 getattr(src.allocator.device, 'fd', None) is not None
+    disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
+                                 getattr(src.allocator.dev, 'fd', None) is not None
     if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
       dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
-    elif src.device.startswith("DISK") and hasattr(dest.allocator, 'as_buffer'):
+    elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
       # fast(ish) path, uses readinto in diskbuffers
-      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
+      src.allocator._copyout(dest.allocator._as_buffer(dest._buf), src._buf)
     else:
       dest.copyin(src.as_buffer(allow_zero_copy=True)) # may allocate a CPU buffer depending on allow_zero_copy
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False):
     dest, src = rawbufs[0:2]
     assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
     st = time.perf_counter()
@@ -136,23 +96,22 @@ class BufferCopy(Runner):
     return time.perf_counter() - st
 
 class BufferXfer(BufferCopy):
-  def copy(self, dest, src): dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
+  def copy(self, dest, src): dest.allocator._transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.dev, dest_dev=dest.allocator.dev)
 
 # **************** method cache ****************
 
-method_cache: Dict[Tuple[str, bytes, int, int, bool], CompiledRunner] = {}
-def get_runner(dname:str, ast:UOp) -> CompiledRunner:
-  ckey = (dname, ast.key, BEAM.value, NOOPT.value, False)
+method_cache: dict[tuple[str, bytes, tuple[int, ...], bool], CompiledRunner] = {}
+def get_runner(device:str, ast:UOp) -> CompiledRunner:
+  # TODO: this should be all context relevant to rendering
+  context = (BEAM.value, NOOPT.value, DEVECTORIZE.value)
+  ckey = (device, ast.key, context, False)
   if cret:=method_cache.get(ckey): return cret
-  bkey = (dname.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
+  bkey = (device.split(":")[0], ast.key, context, True)
   if bret:=method_cache.get(bkey):
-    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
+    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
   else:
-    prg: Program = get_kernel(Device[dname].renderer, ast).to_program()
-    if getenv("FUZZ_UOPS"):
-      from test.external.fuzz_uops import UOpsFuzzerRunner
-      return UOpsFuzzerRunner(replace(prg, dname=dname))
-    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
+    prg: ProgramSpec = get_kernel(Device[device].renderer, ast).to_program()
+    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
   return ret
 
 # **************** lowering functions ****************
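
The method cache above folds the rendering-relevant context into one tuple and shares compiled binaries across devices that differ only in their ":N" suffix. An illustrative sketch of that behavior, assuming the default backend accepts a second ":1" instance:

from tinygrad import Tensor
from tinygrad.device import Device
from tinygrad.engine.realize import get_runner

sink = Tensor.ones(16).contiguous().schedule()[-1].ast  # any kernel AST (a SINK)
r0 = get_runner(Device.DEFAULT, sink)         # compiles; fills both the exact and base-device keys
r1 = get_runner(f"{Device.DEFAULT}:1", sink)  # base-key hit: reuses r0.lib, only ProgramSpec.device changes
assert r0.lib == r1.lib and r1.p.device == f"{Device.DEFAULT}:1"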
@@ -160,43 +119,38 @@ def get_runner(dname:str, ast:UOp) -> CompiledRunner:
 @dataclass(frozen=True)
 class ExecItem:
   prg: Runner
-  bufs: List[Optional[Buffer]]
-  metadata: Optional[Tuple[Metadata, ...]] = None
-  def run(self, _var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
+  bufs: list[Optional[Buffer]]
+  metadata: Optional[tuple[Metadata, ...]] = None
+  def run(self, _var_vals:Optional[dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
     var_vals = {} if _var_vals is None else _var_vals
     bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
     et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
     if do_update_stats:
       GlobalCounters.kernel_count += 1
-      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.op_estimate, var_vals))
-      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.mem_estimate, var_vals))
+      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals))
+      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals))
       if et is not None: GlobalCounters.time_sum_s += et
       if DEBUG >= 2:
-        lds_est = sym_infer(self.prg.lds_estimate, var_vals)
+        lds_est = sym_infer(self.prg.estimates.lds, var_vals)
         mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
-        ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
-        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
+        ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
+        print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
           (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
           f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
     self.prg.first_run = False
     return et
 
-def lower_schedule_item(si:ScheduleItem) -> ExecItem:
-  assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is Ops.COPY
-  if si.ast.op is Ops.SINK:
-    runner = get_runner(si.outputs[0].device, si.ast)
-    return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
-  out, arg = si.outputs[0], si.ast.arg
-  if si.ast.op is Ops.COPY:
-    kernel_type = BufferCopy
-    if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
-      kernel_type = BufferXfer
-    return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
-  if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
-  if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
-  raise RuntimeError(f"don't know how to lower {si.ast}")
-
-def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
+# NOTE: ctx is the buffers
+si_lowerer = PatternMatcher([
+  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
+  (UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
+  (UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
+      if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
+      else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
+])
+def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
+
+def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
   while len(schedule):
     si = schedule.pop(0)
     try: yield lower_schedule_item(si)
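
Lowering is now table-driven: si_lowerer rewrites the AST root (SINK, COPY, or BUFFER_VIEW) into a (Runner, bufs) pair and lower_schedule_item simply wraps the result in an ExecItem. A hedged sketch of that round trip; the empty tensors are placeholders, so the kernel runs on uninitialized inputs:

from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule_item

a, b = Tensor.empty(64, 64), Tensor.empty(64, 64)
si = (a @ b).schedule()[-1]   # the matmul ScheduleItem; its AST root is a SINK
ei = lower_schedule_item(si)  # si_lowerer.rewrite(si.ast, si.bufs) -> (CompiledRunner, bufs)
print(ei.prg.display_name, len(ei.bufs))
ei.run()                      # ensure_allocated() on the bufs, then launch the kernel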
@@ -209,9 +163,9 @@ def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, Non
 
 # **************** main run function ****************
 
-capturing: List = [] # put classes with an add method in here
+capturing: list = [] # put classes with an add method in here
 
-def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None, do_update_stats=True):
+def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, int]]=None, do_update_stats=True):
   for ei in lower_schedule(schedule):
     if len(capturing) and CAPTURING: capturing[0].add(ei)
     ei.run(var_vals, do_update_stats=do_update_stats)
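
For completeness, a hedged sketch of driving this run loop by hand, which is roughly what Tensor.realize() does internally; Tensor.schedule_with_vars() is assumed to be the 0.10.2 entry point returning (schedule, var_vals):

from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

y = (Tensor.rand(256) * 2 + 1).sum()
schedule, var_vals = y.schedule_with_vars()  # list[ScheduleItem], dict[Variable, int]
run_schedule(schedule, var_vals)             # lowers each item to an ExecItem and runs it
print(y.item())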