tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their respective public registries. It is provided for informational purposes only.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/engine/memory.py CHANGED
@@ -1,50 +1,69 @@
+from typing import cast
 from collections import defaultdict
 from tinygrad.engine.schedule import ScheduleItem
 from tinygrad.device import Device, Buffer
-from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG
-from tinygrad.ops import Ops
+from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
+from tinygrad.uop.ops import Ops
+from tinygrad.dtype import dtypes, ImageDType
+from tinygrad.runtime.support.memory import TLSFAllocator
 
 # **************** memory planning ****************
 
-def _internal_memory_planner(buffers:list[list[Buffer]|tuple[Buffer, ...]], noopt_buffers=None, debug_prefix="") -> dict[Buffer, Buffer]:
+def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ignore_checks=False, debug_prefix="") -> dict[Buffer, Buffer]:
   if NO_MEMORY_PLANNER: return {}
-  first_appearance, last_appearance = {}, {}
+  first_appearance, last_appearance, buf_to_opt = {}, {}, set()
   for i,u in enumerate(buffers):
     for buf in u:
-      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
+      should_skip = buf.is_allocated() or buf.base.is_allocated() or buf.uop_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers)
+      if not ignore_checks and should_skip: continue
       if buf.base not in first_appearance: first_appearance[buf.base] = i
       last_appearance[buf.base] = i
+      buf_to_opt.add(buf)
 
-  # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
-  # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
-  free_segs: dict[tuple, list[tuple[int, int, Buffer]]] = defaultdict(list) # dict[buffer key, tuple[start, end, buffer to reuse on the seg]]
-  def find_replace_buffer(buf, st, en):
-    key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())
+  # Sort buffer operations in timeline order. Two events: buffer is allocated or buffer is freed.
+  buffer_requests = sorted([((first_appearance[buf], True), buf) for buf in first_appearance.keys()] + \
+                           [((last_appearance[buf] + 1, False), buf) for buf in first_appearance.keys()], key=lambda x: x[0])
 
-    default_buf = (0, len(buffers) - 1, buf) # will return the buffer itself if the replace one is not found.
-    seg_st, seg_en, seg_buf = next((free_segs[key].pop(i) for i,(sst,sen,_) in enumerate(free_segs[key]) if sst <= st and en <= sen), default_buf)
+  # Try to suballocate from a shared buffer managed by global_planner using TLSFAllocator.
+  # Also track buffer replacements for buffers that do not support suballocation.
+  buffer_replace:dict[Buffer, tuple[Buffer|None, int|None]] = {}
+  reuse_buffers:dict[tuple, list[Buffer]] = defaultdict(list)
+  global_planner:dict[str, tuple[int, TLSFAllocator]] = defaultdict(lambda: (0, TLSFAllocator(1 << 44, block_size=0x1000, lv2_cnt=32)))
+  for (_, is_open_ev), buf in buffer_requests:
+    # Check if suballocation is possible for the given buffer and device.
+    if hasattr(Device[buf.device].allocator, "_offset") and not isinstance(buf.dtype, ImageDType):
+      if is_open_ev: buffer_replace[buf] = (None, global_planner[buf.device][1].alloc(round_up(buf.nbytes, 0x1000)))
+      else: global_planner[buf.device][1].free(cast(int, buffer_replace[buf][1]))
+      global_planner[buf.device] = (max(global_planner[buf.device][0], buffer_replace[buf][1] + buf.nbytes), global_planner[buf.device][1])
+    else:
+      key = (buf.device, buf.dtype, buf.options, buf.nbytes)
+      if is_open_ev: buffer_replace[buf] = (reuse_buffers[key].pop(), None) if key in reuse_buffers and len(reuse_buffers[key]) > 0 else (buf, None)
+      else: reuse_buffers[key].append(cast(Buffer, buffer_replace[buf][0]))
 
-    free_segs[key] += [(seg_st, st - 1, seg_buf)] if st - 1 >= seg_st else []
-    free_segs[key] += [(en + 1, seg_en, seg_buf)] if seg_en >= en + 1 else []
+  # Allocate global buffers based on the memory planner.
+  global_buffers = {dev: Buffer(dev, round_up(sz, 0x1000), dtypes.int8) for dev, (sz, _) in global_planner.items()}
+  buffer_resolve:dict[Buffer, tuple[Buffer, int|None]] = {buf: (base or global_buffers[buf.device], off) for buf,(base,off) in buffer_replace.items()}
 
-    return seg_buf if seg_buf.nbytes == buf.nbytes else Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)
+  # Assign buffers. First, assign full buffers (not sub-buffers).
+  assigned:dict[Buffer, Buffer] = {}
+  for buf, (base, off) in buffer_resolve.items():
+    if buf != base:
+      assigned[buf] = base if off is None else Buffer(buf.device, buf.size, buf.dtype, base=base, offset=off)
 
-  buffer_requests = sorted([(first_appearance[buf], last_appearance[buf], buf) for buf in first_appearance.keys()], key=lambda x: -x[2].nbytes)
-  assigned = {buf:find_replace_buffer(buf, st, en) for st, en, buf in buffer_requests}
+  # Now assign sub-buffers.
+  for buf in buf_to_opt:
+    if buf._base is not None:
+      assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=(pbuf:=assigned.get(buf.base, buf.base)).base, offset=pbuf.offset+buf.offset)
 
-  for i,u in enumerate(buffers):
-    for buf in u:
-      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
-      if buf._base is not None: assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf.base, buf.base).base, offset=buf.offset)
-      else: assigned[buf] = assigned.get(buf, buf)
+  if DEBUG >= 1:
+    ak, av = dedup(x for x in assigned.keys() if x._base is None),dedup(x for x in assigned.values() if x._base is None)+list(global_buffers.values())
+    omem, nmem = sum([x.nbytes for x in ak])/1e6, sum([x.nbytes for x in av])/1e6
+    if omem != nmem: print(f"{debug_prefix}memory reduced from {omem:.2f} MB -> {nmem:.2f} MB,", f"{len(ak)} -> {len(av)} bufs")
 
-  if DEBUG >= 1 and len(ak:=dedup(x for x in assigned.keys() if x._base is None)) != len(av:=dedup(x for x in assigned.values() if x._base is None)):
-    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
-          f"{len(ak)} -> {len(av)} bufs")
   return assigned
 
 def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
   # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
-  assigned = _internal_memory_planner([si.bufs for si in schedule],
+  assigned = _internal_memory_planner([list(si.bufs) for si in schedule],
                                       noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
-  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
+  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.fixedvars) for si in schedule]
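
Note (not part of the diff): the planner above replaces per-buffer reuse with an event-timeline pass. Every optimizable buffer becomes an alloc event at its first appearance and a free event after its last; devices whose allocator exposes _offset pack those events into one shared arena via TLSFAllocator, and everything else falls back to whole-buffer reuse keyed on (device, dtype, options, nbytes). A minimal sketch of the timeline idea using a toy first-fit allocator (plan_offsets is a hypothetical helper, not tinygrad API):

    # plan_offsets: toy stand-in for the timeline + TLSFAllocator scheme above (hypothetical helper).
    # lifetimes: name -> (first_use, last_use, nbytes); returns (offsets, arena_size).
    def plan_offsets(lifetimes: dict[str, tuple[int, int, int]]) -> tuple[dict[str, int], int]:
      # free events (at last_use+1) sort before alloc events at the same step, like (idx, False) < (idx, True)
      events = sorted([(last + 1, 0, name) for name, (_, last, _) in lifetimes.items()] +
                      [(first, 1, name) for name, (first, _, _) in lifetimes.items()])
      free_blocks: list[tuple[int, int]] = []  # (offset, size) regions available for reuse
      offsets: dict[str, int] = {}
      arena_size = 0
      for _, is_alloc, name in events:
        nbytes = lifetimes[name][2]
        if not is_alloc:                       # free: return the block to the pool
          free_blocks.append((offsets[name], nbytes))
          continue
        for i, (off, sz) in enumerate(free_blocks):  # alloc: first fit into a freed block
          if sz >= nbytes:
            offsets[name] = off
            free_blocks[i:i+1] = [(off + nbytes, sz - nbytes)] if sz > nbytes else []
            break
        else:                                  # no fit: grow the arena, like global_planner's high-water mark
          offsets[name], arena_size = arena_size, arena_size + nbytes
      return offsets, arena_size

    # "a" dies before "b" is first used, so both land at offset 0 of a 4096-byte arena
    print(plan_offsets({"a": (0, 1, 4096), "b": (2, 3, 4096)}))  # ({'a': 0, 'b': 0}, 4096)

Buffers with disjoint lifetimes share storage, which is exactly the saving the DEBUG >= 1 message reports.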
tinygrad/engine/realize.py CHANGED
@@ -1,30 +1,51 @@
-from typing import Optional, cast, Generator
+from typing import cast, Generator
 import time, pprint
-from dataclasses import dataclass, replace
-from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
-from tinygrad.helpers import DEVECTORIZE, time_to_str
-from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
+from dataclasses import dataclass, replace, field
+from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
+from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile
+from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
-from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import ScheduleItem
+from tinygrad.codegen import full_rewrite
+from tinygrad.codegen.opt.kernel import Opt
 
 # **************** Program Creation ****************
 
-logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
-def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
-  if DEBUG >= 5: print(ast)
-  k = Kernel(ast, opts=renderer).required_optimizations()
-  if not NOOPT:
-    if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
-    if BEAM >= 1:
-      from tinygrad.engine.search import beam_search, bufs_from_lin
-      kb = Kernel(ast, opts=renderer).required_optimizations()
-      rawbufs = bufs_from_lin(kb, allocate=False)
-      k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-  if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
-  if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
-  return k
+@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
+def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
+  """
+  Transform an AST into a ProgramSpec. May trigger BEAM search.
+
+  Args:
+    ast: The Ops.SINK rooted AST
+    renderer: The renderer used to generate the code
+
+  Returns:
+    The ProgramSpec of the program.
+  """
+
+  if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
+
+  # linearize
+  if renderer is None: renderer = Device.default.renderer
+  if opts is not None:
+    assert ast.arg is None, "can't apply opts if sink has an arg"
+    ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))
+  try:
+    uops = full_rewrite(ast, renderer)
+  except RuntimeError:
+    print("***** LINEARIZE FAILURE *****")
+    print(f"ast = {ast}")
+    raise
+  assert uops[-1].op is Ops.SINK, "last uop must be sink"
+
+  # print and render
+  if DEBUG >= 6: print_uops(uops)
+  src = renderer.render(uops)
+
+  return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
+                     global_size=[1,1,1] if renderer.has_local else None, local_size=[1,1,1] if renderer.has_local else None)
 
 # **************** Runners ****************
 
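Note (not part of the diff): get_program is now the single entry point that the old get_kernel(...).to_program() pair covered. A hypothetical way to exercise it from user code, assuming Tensor.schedule() from the public API and that the last ScheduleItem holds an Ops.SINK compute kernel:

    # hypothetical usage sketch; Tensor.schedule() and Device.default are assumed from the public API
    from tinygrad import Tensor, Device
    from tinygrad.engine.realize import get_program, CompiledRunner

    t = (Tensor.ones(16, 16).contiguous() + 1)
    sched = t.schedule()                             # list[ScheduleItem], not yet executed
    ast = sched[-1].ast                              # assumed: an Ops.SINK rooted kernel AST
    prg = get_program(ast, Device.default.renderer)  # ProgramSpec with name, src, uops, estimates
    print(prg.name)
    print(prg.src)                                   # rendered source for the default backend
    runner = CompiledRunner(prg)                     # compiles it with the device's compiler

Passing opts=... instead attaches KernelInfo(opts_to_apply=...) to the sink, as the function body above shows.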
@@ -33,27 +54,30 @@ class Runner:
     self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
   @property
   def dev(self): return Device[self.device]
-  def exec(self, rawbufs:list[Buffer], var_vals:Optional[dict[Variable, int]]=None) -> Optional[float]:
+  def exec(self, rawbufs:list[Buffer], var_vals:dict[Variable, int]|None=None) -> float|None:
    return self(rawbufs, {} if var_vals is None else var_vals)
-  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> float|None:
     raise NotImplementedError("override this")
 
 class CompiledRunner(Runner):
-  def __init__(self, p:ProgramSpec, precompiled:Optional[bytes]=None):
+  def __init__(self, p:ProgramSpec, precompiled:bytes|None=None, prg=None):
     if DEBUG >= 4: print(p.src)
     self.p:ProgramSpec = p
-    self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
-    if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
-    self._prg = Device[p.device].runtime(p.function_name, self.lib)
+    if precompiled is not None: self.lib = precompiled
+    else:
+      with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,), cat="compiler"), "TINY"):
+        self.lib = Device[p.device].compiler.compile_cached(p.src)
+    if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
+    self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
     super().__init__(p.name, p.device, p.estimates)
 
   def __reduce__(self): return self.__class__, (self.p, self.lib)
 
-  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> Optional[float]:
+  def __call__(self, rawbufs:list[Buffer], var_vals:dict[Variable, int], wait=False) -> float|None:
     global_size, local_size = self.p.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.p.global_size): # type: ignore[arg-type]
       # TODO: this is copied from get_program
-      from tinygrad.engine.search import optimize_local_size
+      from tinygrad.codegen.opt.search import optimize_local_size
      local_size = optimize_local_size(self._prg, global_size, rawbufs)
       global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
       self.p = replace(self.p, global_size=global_size, local_size=local_size)
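Note (not part of the diff): the launch-dimension fixup in __call__ splits the flat global size by whatever local size optimize_local_size picks; the arithmetic in plain Python:

    # if the kernel was rendered without a fixed local size, the chosen local size divides the global size
    global_size, local_size = [4096, 512, 1], [16, 16, 1]
    launch_grid = [g//l if g % l == 0 else g/l for g, l in zip(global_size, local_size)]
    print(launch_grid)  # [256, 32, 1]: grid of workgroups, each 16x16x1 threads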
@@ -78,7 +102,7 @@ class BufferCopy(Runner):
     super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
   def copy(self, dest, src):
     disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
-      getattr(src.allocator.dev, 'fd', None) is not None
+      getattr(src.allocator.dev, 'fd', None) is not None and dest.allocator.supports_copy_from_disk
     if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
       dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
     elif src.device.startswith("DISK") and hasattr(dest.allocator, '_as_buffer'):
@@ -110,7 +134,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
   if bret:=method_cache.get(bkey):
     method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
   else:
-    prg: ProgramSpec = get_kernel(Device[device].renderer, ast).to_program()
+    prg: ProgramSpec = get_program(ast, Device[device].renderer)
     method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
   return ret
 
@@ -119,10 +143,11 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
 @dataclass(frozen=True)
 class ExecItem:
   prg: Runner
-  bufs: list[Optional[Buffer]]
-  metadata: Optional[tuple[Metadata, ...]] = None
-  def run(self, _var_vals:Optional[dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
-    var_vals = {} if _var_vals is None else _var_vals
+  bufs: list[Buffer|None]
+  metadata: tuple[Metadata, ...]|None = None
+  fixedvars: dict[Variable, int] = field(default_factory=dict)
+  def run(self, _var_vals:dict[Variable, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
+    var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
     bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
     et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
     if do_update_stats:
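Note (not part of the diff): the one behavioral subtlety of the new fixedvars field is the merge order in run: values pinned on the ExecItem override caller-supplied var_vals. Plain dicts standing in for tinygrad Variables:

    caller_var_vals = {"i": 3, "n": 128}
    fixedvars       = {"n": 256}                # bound when the ScheduleItem was created
    merged = caller_var_vals | fixedvars        # dict union: right-hand side wins on conflict
    assert merged == {"i": 3, "n": 256}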
@@ -134,7 +159,7 @@ class ExecItem:
       lds_est = sym_infer(self.prg.estimates.lds, var_vals)
       mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
       ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
-      print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
+      print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(44-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
             (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
              f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
       self.prg.first_run = False
@@ -148,12 +173,13 @@ si_lowerer = PatternMatcher([
       if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
       else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
 ])
-def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
+def lower_schedule_item(si:ScheduleItem) -> ExecItem:
+  return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata, si.fixedvars)
 
-def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
+def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
   while len(schedule):
     si = schedule.pop(0)
-    try: yield lower_schedule_item(si)
+    try: yield (si, lower_schedule_item(si))
     except Exception as e:
       if DEBUG >= 2:
         print(f"error lowering {si.ast.op}")
@@ -165,7 +191,22 @@ def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
 
 capturing: list = [] # put classes with an add method in here
 
-def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, int]]=None, do_update_stats=True):
-  for ei in lower_schedule(schedule):
+def run_schedule(schedule:list[ScheduleItem], var_vals:dict[Variable, int]|None=None, do_update_stats=True):
+  for si, ei in lower_schedule(schedule):
     if len(capturing) and CAPTURING: capturing[0].add(ei)
-    ei.run(var_vals, do_update_stats=do_update_stats)
+    if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
+      # copy in allocated buffers from the GPU
+      nb: tuple[Buffer, ...] = tuple(Buffer("CPU", b.size, b.dtype) for b in si.bufs)
+      for cpu_b, gpu_b in zip(nb, si.bufs):
+        if gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
+
+      # run on GPU
+      ei.run(var_vals, do_update_stats=do_update_stats)
+
+      # validate the output buffers match (NOTE: this is assuming the output is buffer 0)
+      lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats)
+      import numpy as np
+      np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
+    else:
+      ei.run(var_vals, do_update_stats=do_update_stats)
+
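
Note (not part of the diff): the VALIDATE_WITH_CPU branch re-runs every Ops.SINK kernel on CPU buffers and compares output buffer 0 with numpy's assert_allclose(rtol=1e-3, atol=1e-3). A hypothetical way to use it, assuming the flag is read from the environment like other tinygrad.helpers context variables:

    # run as: VALIDATE_WITH_CPU=1 python validate_example.py   (flag name taken from the import above)
    from tinygrad import Tensor

    x = Tensor.randn(64, 64)
    y = (x @ x.T).relu().sum()
    print(y.item())  # realization drives run_schedule; each SINK kernel is cross-checked against CPU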