tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,51 @@
1
+ from typing import List, Union, Tuple, Dict
2
+ from collections import defaultdict
3
+ from tinygrad.engine.schedule import ScheduleItem
4
+ from tinygrad.device import Device, Buffer
5
+ from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG
6
+ from tinygrad.ops import Ops
7
+
8
+ # **************** memory planning ****************
9
+
10
def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
  """Compute a buffer -> replacement-buffer mapping so that buffers with disjoint lifetimes can share storage.

  Args:
    buffers: one entry per step; entry `i` is the collection of buffers used at step `i`.
    noopt_buffers: optional container of base buffers that must be left out of planning.
    debug_prefix: text prepended to the summary line printed when DEBUG >= 1.

  Returns:
    A dict mapping every planned buffer to the buffer that should be used in its place (which may be the
    buffer itself). An empty dict is returned when NO_MEMORY_PLANNER is set.
  """
  if NO_MEMORY_PLANNER: return {}

  def is_skipped(b):
    # Not planned: already allocated, lb_refcount still positive, or base explicitly excluded by the caller.
    return b.is_allocated() or b.lb_refcount > 0 or (noopt_buffers is not None and b.base in noopt_buffers)

  # Pass 1: record the first and last step at which each base buffer shows up.
  first_seen, last_seen = {}, {}
  for step, step_bufs in enumerate(buffers):
    for b in step_bufs:
      if is_skipped(b): continue
      base = b.base
      first_seen.setdefault(base, step)
      last_seen[base] = step

  # Per key, a list of free segments: (first free step, last free step, buffer that can be reused on that segment).
  free_segs: Dict[Tuple, List[Tuple[int, int, Buffer]]] = defaultdict(list)

  def find_replace_buffer(buf, st, en):
    # Without an "offset" attribute on the allocator, the byte size is part of the key, so only same-sized buffers share.
    key = (buf.device, buf.dtype, buf.options)
    if not hasattr(Device[buf.device].allocator, "offset"): key += (buf.nbytes,)
    candidates = free_segs[key]

    # Fallback when nothing fits: the buffer backs itself and is considered free over the whole step range.
    seg_st, seg_en, seg_buf = 0, len(buffers) - 1, buf
    for idx, (sst, sen, _) in enumerate(candidates):
      if sst <= st and en <= sen:
        # First free segment that fully contains [st, en] wins; take it out of the free list.
        seg_st, seg_en, seg_buf = candidates.pop(idx)
        break

    # Return whatever is left of the segment on either side of [st, en] to the free list.
    if seg_st <= st - 1: candidates.append((seg_st, st - 1, seg_buf))
    if en + 1 <= seg_en: candidates.append((en + 1, seg_en, seg_buf))

    if seg_buf.nbytes == buf.nbytes: return seg_buf
    return Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)

  # Pass 2: serve requests from the largest buffer to the smallest (stable for equal sizes).
  requests = [(first_seen[base], last_seen[base], base) for base in first_seen]
  requests.sort(key=lambda req: -req[2].nbytes)
  assigned = {}
  for st, en, base in requests: assigned[base] = find_replace_buffer(base, st, en)

  # Pass 3: map every planned buffer; buffers that have a base are recreated on top of their base's replacement.
  for step_bufs in buffers:
    for b in step_bufs:
      if is_skipped(b): continue
      if b._base is None:
        assigned[b] = assigned.get(b, b)
      else:
        assigned[b] = Buffer(b.device, b.size, b.dtype, base=assigned.get(b.base, b.base).base, offset=b.offset)

  if DEBUG >= 1:
    # Only report when the number of distinct base buffers actually changed.
    ak = dedup(x for x in assigned.keys() if x._base is None)
    av = dedup(x for x in assigned.values() if x._base is None)
    if len(ak) != len(av):
      print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
            f"{len(ak)} -> {len(av)} bufs")
  return assigned
46
+
47
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
  """Return a new schedule in which each item's buffers are replaced according to the internal memory planner.

  Every buffer of an item whose ast op is not Ops.SINK is passed to the planner as a no-optimize buffer.
  Items keep their ast, metadata and assign_preloads; only the buffer tuple is rebuilt.
  """
  per_item_bufs = [item.bufs for item in schedule]

  # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
  excluded = set()
  for item in schedule:
    if item.ast.op is not Ops.SINK: excluded.update(item.bufs)

  replacements = _internal_memory_planner(per_item_bufs, noopt_buffers=excluded)

  planned = []
  for item in schedule:
    # Buffers the planner did not map are kept as they are.
    new_bufs = tuple(replacements.get(b, b) for b in item.bufs)
    planned.append(ScheduleItem(item.ast, new_bufs, item.metadata, item.assign_preloads))
  return planned
@@ -1,42 +1,62 @@
1
1
  from typing import List, Dict, Optional, cast, Generator, Tuple
2
- import time
2
+ import time, pprint
3
3
  from dataclasses import dataclass, replace
4
- from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING
5
- from tinygrad.ops import BufferOps, LoadOps, LazyOp
4
+ from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA
5
+ from tinygrad.ops import Ops, UOp, Variable, sym_infer, sint
6
+ from tinygrad.dtype import dtypes
6
7
  from tinygrad.device import Device, Buffer
7
- from tinygrad.shape.symbolic import Variable, sym_infer, sint
8
8
  from tinygrad.renderer import Renderer, Program
9
- from tinygrad.codegen.linearizer import Linearizer
9
+ from tinygrad.codegen.kernel import Kernel
10
10
  from tinygrad.engine.schedule import ScheduleItem
11
11
 
12
12
  # **************** Program Creation ****************
13
13
 
14
14
  logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
15
- def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
16
- if DEBUG >= 3:
17
- from tinygrad.engine.graph import print_tree
18
- for op in ast: print_tree(op)
19
- k = Linearizer(*ast, opts=renderer)
20
- k.required_optimizations()
15
+ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
16
+ if DEBUG >= 5:
17
+ print(ast)
18
+ k = Kernel(ast, opts=renderer).required_optimizations()
21
19
  if not NOOPT:
22
20
  if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
23
21
  if BEAM >= 1:
24
22
  from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
25
- kb, k_opt = Linearizer(*ast, opts=renderer), k
26
- kb.required_optimizations()
23
+ kb, k_opt = Kernel(ast, opts=renderer).required_optimizations(), k
27
24
  rawbufs = bufs_from_lin(kb, allocate=False)
28
- k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
29
- if getenv("BEAM_COMPARE", 1):
25
+ if BEAM.value >= 100:
26
+ from extra.mcts_search import mcts_search
27
+ k = mcts_search(kb, rawbufs, BEAM.value)
28
+ else:
29
+ k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
30
+ if beam_compare:=getenv("BEAM_COMPARE", 1):
30
31
  # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
31
- lins: List[Tuple[str, Linearizer]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
32
- if used_tensor_cores:
33
- lins.append(("hc", Linearizer(*ast, opts=renderer)))
34
- lins[-1][1].hand_coded_optimizations()
32
+ lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
33
+ if used_tensor_cores: lins.append(("hc", Kernel(ast, opts=renderer).hand_coded_optimizations()))
35
34
  timed = sorted([(nm, tk, time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
36
- if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
35
+ if DEBUG >= 3: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
37
36
  k = timed[0][1]
38
37
  if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
39
- # TODO: check the correctness inline once compare_linearizer is in core
38
+ if beam_compare == 2:
39
+ from tinygrad import Tensor
40
+ all_outs: List[List[Tensor]] = []
41
+ with Context(DEBUG=0, BEAM=0, CAPTURING=0):
42
+ rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
43
+ (Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
44
+ Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
45
+ for buf in rawbufs]
46
+ for _, tk in lins[::-1]:
47
+ for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
48
+ time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
49
+ all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
50
+ with Context(DEBUG=0, BEAM=0, CAPTURING=0):
51
+ for bufs in zip(*all_outs):
52
+ for b in bufs[1:]:
53
+ if dtypes.is_float(bufs[0].dtype):
54
+ # we check both atol and rtol here
55
+ diff_count = (((b-bufs[0]).abs() > 1e-3) * (((b-bufs[0])/bufs[0]).abs() > 1e-3)).sum().item()
56
+ else:
57
+ diff_count = (b != bufs[0]).sum().item()
58
+ if diff_count != 0:
59
+ raise RuntimeError(f"mismatch of {diff_count}/{b.numel()} items with type {b.dtype}, max {(b-bufs[0]).abs().max().item()}")
40
60
  if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
41
61
  if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
42
62
  return k
@@ -44,8 +64,9 @@ def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
44
64
  # **************** Runners ****************
45
65
 
46
66
  class Runner:
47
- def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0):
48
- self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate = True, display_name, dname, op_estimate, mem_estimate
67
+ def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
68
+ self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate, self.lds_estimate = \
69
+ True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
49
70
  @property
50
71
  def device(self): return Device[self.dname]
51
72
  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
@@ -58,8 +79,9 @@ class CompiledRunner(Runner):
58
79
  if DEBUG >= 4: print(p.src)
59
80
  self.p:Program = p
60
81
  self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
82
+ if DEBUG >= 6: Device[p.dname].compiler.disassemble(self.lib)
61
83
  self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
62
- super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate)
84
+ super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate, p.lds_estimate)
63
85
 
64
86
  def __reduce__(self): return self.__class__, (self.p, self.lib)
65
87
 
@@ -73,19 +95,13 @@ class CompiledRunner(Runner):
73
95
  self.p = replace(self.p, global_size=global_size, local_size=local_size)
74
96
  lra = {}
75
97
  if global_size:
76
- lra['global_size'] = global_size
98
+ lra['global_size'] = tuple(global_size)
77
99
  assert len(global_size) == 3, "global size must have len 3"
78
100
  if local_size:
79
- lra['local_size'] = local_size
101
+ lra['local_size'] = tuple(local_size)
80
102
  assert len(local_size) == 3, "local size must have len 3"
81
103
  return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
82
104
 
83
- class CustomOp(Runner):
84
- def __init__(self, fxn):
85
- self.fxn = fxn
86
- super().__init__(self.fxn.__name__, "CUSTOM", 0, 0)
87
- def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): self.fxn(*rawbufs)
88
-
89
105
  class EmptyOp(Runner):
90
106
  def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
91
107
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
@@ -101,7 +117,8 @@ class BufferCopy(Runner):
101
117
  else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
102
118
  super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
103
119
  def copy(self, dest, src):
104
- disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.device, 'io_uring') and hasattr(src.allocator.device, 'fd')
120
+ disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.device, 'io_uring') and \
121
+ getattr(src.allocator.device, 'fd', None) is not None
105
122
  if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
106
123
  dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
107
124
  elif src.device.startswith("DISK") and hasattr(dest.allocator, 'as_buffer'):
@@ -119,24 +136,20 @@ class BufferCopy(Runner):
119
136
  return time.perf_counter() - st
120
137
 
121
138
  class BufferXfer(BufferCopy):
122
- def copy(self, dest, src):
123
- if hasattr(dest.allocator.device, "track_cross_buffer") and hasattr(src.allocator, "track_cross_device"):
124
- dest.allocator.device.track_cross_buffer.append(src)
125
- src.allocator.track_cross_device.add(dest.allocator.device)
126
- dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
139
+ def copy(self, dest, src): dest.allocator.transfer(dest._buf, src._buf, dest.nbytes, src_dev=src.allocator.device, dest_dev=dest.allocator.device)
127
140
 
128
141
  # **************** method cache ****************
129
142
 
130
- method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], int, bool], CompiledRunner] = {}
131
- def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
132
- ckey = (dname, ast, BEAM.value, False)
143
+ method_cache: Dict[Tuple[str, bytes, int, int, bool], CompiledRunner] = {}
144
+ def get_runner(dname:str, ast:UOp) -> CompiledRunner:
145
+ ckey = (dname, ast.key, BEAM.value, NOOPT.value, False)
133
146
  if cret:=method_cache.get(ckey): return cret
134
- bkey = (dname.split(":")[0], ast, BEAM.value, True)
147
+ bkey = (dname.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
135
148
  if bret:=method_cache.get(bkey):
136
149
  method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
137
150
  else:
138
- prg: Program = get_linearizer(Device[dname].renderer, ast).to_program()
139
- if hasattr(prg.uops, "fuzz_paths"):
151
+ prg: Program = get_kernel(Device[dname].renderer, ast).to_program()
152
+ if getenv("FUZZ_UOPS"):
140
153
  from test.external.fuzz_uops import UOpsFuzzerRunner
141
154
  return UOpsFuzzerRunner(replace(prg, dname=dname))
142
155
  method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
@@ -148,39 +161,51 @@ def get_runner(dname:str, ast:Tuple[LazyOp, ...]) -> CompiledRunner:
148
161
  class ExecItem:
149
162
  prg: Runner
150
163
  bufs: List[Optional[Buffer]]
151
- def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
164
+ metadata: Optional[Tuple[Metadata, ...]] = None
165
+ def run(self, _var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
166
+ var_vals = {} if _var_vals is None else _var_vals
152
167
  bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
153
- et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
168
+ et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
154
169
  if do_update_stats:
155
170
  GlobalCounters.kernel_count += 1
156
- GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))
157
- GlobalCounters.global_mem += (mem_estimate:=sym_infer(self.prg.mem_estimate, var_vals))
171
+ GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.op_estimate, var_vals))
172
+ GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.mem_estimate, var_vals))
158
173
  if et is not None: GlobalCounters.time_sum_s += et
159
174
  if DEBUG >= 2:
175
+ lds_est = sym_infer(self.prg.lds_estimate, var_vals)
176
+ mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
160
177
  ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
161
- print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
162
- (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501
178
+ print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
179
+ (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
180
+ f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
163
181
  self.prg.first_run = False
164
182
  return et
165
183
 
166
184
  def lower_schedule_item(si:ScheduleItem) -> ExecItem:
167
- assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY or getenv("USE_COPY_KERNEL")
168
- if si.ast[0].op is BufferOps.STORE:
185
+ assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is Ops.COPY
186
+ if si.ast.op is Ops.SINK:
169
187
  runner = get_runner(si.outputs[0].device, si.ast)
170
- return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals])
171
- out, ast = si.outputs[0], si.ast[0]
172
- if ast.op is LoadOps.COPY:
188
+ return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
189
+ out, arg = si.outputs[0], si.ast.arg
190
+ if si.ast.op is Ops.COPY:
173
191
  kernel_type = BufferCopy
174
192
  if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
175
193
  kernel_type = BufferXfer
176
- return ExecItem(kernel_type(ast.arg, out.device, si.inputs[0].device), list(si.bufs))
177
- if ast.op is LoadOps.CUSTOM: return ExecItem(CustomOp(ast.arg), list(si.bufs))
178
- if ast.op is LoadOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
179
- if ast.op is LoadOps.VIEW: return ExecItem(ViewOp(out), list(si.bufs))
180
- raise RuntimeError(f"don't know how to lower {ast}")
194
+ return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
195
+ if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
196
+ if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
197
+ raise RuntimeError(f"don't know how to lower {si.ast}")
181
198
 
182
199
  def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
183
- while len(schedule): yield lower_schedule_item(schedule.pop(0))
200
+ while len(schedule):
201
+ si = schedule.pop(0)
202
+ try: yield lower_schedule_item(si)
203
+ except Exception as e:
204
+ if DEBUG >= 2:
205
+ print(f"error lowering {si.ast.op}")
206
+ print("tensor operations:")
207
+ pprint.pprint(si.metadata, indent=2)
208
+ raise e
184
209
 
185
210
  # **************** main run function ****************
186
211