tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/search.py CHANGED
@@ -6,9 +6,9 @@ from tinygrad.device import Device, Buffer, Compiler
  from tinygrad.ops import MemBuffer
  from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
  from tinygrad.dtype import ImageDType
- from tinygrad.codegen.linearizer import Linearizer
+ from tinygrad.codegen.kernel import Kernel
  from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
- from tinygrad.codegen.uops import UOpGraph
+ from tinygrad.codegen.uopgraph import UOpGraph
  from tinygrad.tensor import Tensor
  from tinygrad.shape.symbolic import sym_infer
  from tinygrad.engine.realize import CompiledRunner
@@ -22,6 +22,7 @@ actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for a
  if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
  actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
  actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+ actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
  if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

  def _get_test_global_size(global_size, max_global_size, var_vals):
@@ -42,25 +43,27 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
  try: car = CompiledRunner(p, precompiled=lib)
  except AssertionError: return [math.inf] * cnt
  tms = []
- input_bufs = [rawbufs[i] for i,_ in car.p.globals]
+ input_bufs = [rawbufs[i] for i in car.p.globals]
  for _ in range(cnt):
  if clear_l2:
- with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
+ if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches()
+ else:
+ with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
  tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
- if early_stop is not None and early_stop < tms[-1]: break
+ if early_stop is not None and early_stop < min(tms): break
  return tms

  class TimeoutException(Exception): pass
  def timeout_handler(signum, frame): raise TimeoutException()

- def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
+ def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
  signal.signal(signal.SIGALRM, timeout_handler)
  # set timeout
  signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
  try:
- x[1].linearize()
- if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
- p = x[1].to_program()
+ p = x[1].to_program(name_override="test")
+ assert p.uops is not None, "uop list wasn't generated?"
+ if len(p.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
  st = time.perf_counter()
  prog = compiler.compile(p.src)
  et = time.perf_counter() - st
@@ -85,7 +88,7 @@ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_
  # *** external API ***

  # get (scrap) buffers for timing the linearizer
- def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]:
+ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
  bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
  for x in lin.membufs: bufsts[x.idx].append(x)
  rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
@@ -97,7 +100,7 @@ def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]:
  return cast(List[Buffer], rawbufs)

  # get dictionary of all possible actions
- def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
+ def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
  acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
  for i,a in enumerate(actions):
  if a.axis is not None and a.op is not OptOps.TC:
@@ -115,15 +118,15 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
  return acted_lins

  beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
- def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer:
+ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel:
  global beam_pool
- key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
- if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
+ key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
+ if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
  ret = lin.copy()
  for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
  return ret

- beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))]
+ beam: List[Tuple[Kernel, float]] = [(lin, float("inf"))]
  seen_libs = set()

  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
@@ -136,20 +139,24 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T

  try:
  rawbufs = _ensure_buffer_alloc(rawbufs)
- var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
+ var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
  exiting, st = False, time.perf_counter()
  dev = Device[lin.opts.device]
  while not exiting:
- acted_lins: List[Linearizer] = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
- timed_lins: List[Tuple[Linearizer, float]] = []
+ acted_lins: List[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
+ timed_lins: List[Tuple[Kernel, float]] = []
  _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
+ least_compute_ops = math.inf
  for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
  if proc is None: continue
  p, lib, compile_et = proc
  if lib in seen_libs: continue
+ # filter out kernels that use 1000x more compute than the smallest
+ least_compute_ops = min(this_compute_ops:=sym_infer(p.op_estimate, var_vals), least_compute_ops)
+ if least_compute_ops*1000 < this_compute_ops: continue
  #print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
  seen_libs.add(lib)
- try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
+ try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
  except RuntimeError: continue # for runtime issues
  timed_lins.append((acted_lins[i], min(tms)))
  if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(UOpGraph, p.uops).uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
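Editor's note: for reference, a standalone sketch of the compute-ops filter added in this hunk, with the logic mirrored outside tinygrad (the candidate names and op counts below are made up). Like the loop it comes from, the filter is order-dependent: each candidate is compared against the cheapest kernel seen so far.

import math

def filter_candidates(candidates):          # candidates: [(name, estimated_compute_ops), ...]
  least, kept = math.inf, []
  for name, ops in candidates:
    least = min(ops, least)                 # cheapest kernel seen so far
    if least * 1000 < ops: continue         # same 1000x test as in beam_search
    kept.append(name)
  return kept

print(filter_candidates([("k1", 10_000), ("k2", 50_000_000), ("k3", 9_000)]))  # ['k1', 'k3']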
@@ -181,8 +188,8 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
  assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
  return ret[1]

- def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
- key = {"ast": lin.ast[0].key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
+ def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
+ key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
  "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

@@ -190,7 +197,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
  assert dev.compiler is not None

  rawbufs = _ensure_buffer_alloc(rawbufs)
- var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
+ var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
  p = lin.to_program()
  tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
  max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
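Editor's note: a hedged end-to-end sketch of the renamed search API (a Kernel goes in, an optimized Kernel comes out). How the Kernel is built from a ScheduleItem (Tensor.schedule(), si.ast, Device[...].renderer) is an assumption about the surrounding 0.9.2 internals, not something shown in this diff.

from tinygrad import Tensor, Device
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import beam_search, bufs_from_lin, time_linearizer

si = (Tensor.rand(64, 64) @ Tensor.rand(64, 64)).schedule()[-1]  # assume the last ScheduleItem is the matmul kernel
lin = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)       # assumed constructor shape
rawbufs = bufs_from_lin(lin)                                     # scrap buffers, as above
best = beam_search(lin, rawbufs, amt=4)                          # BEAM width 4, returns a Kernel
print(best.applied_opts, time_linearizer(best, rawbufs))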
tinygrad/function.py CHANGED
@@ -106,6 +106,15 @@ class Neq(Function):
  class Xor(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)

+ class BitwiseAnd(Function):
+ def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.AND, y)
+
+ class BitwiseOr(Function):
+ def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.OR, y)
+
+ class Threefry(Function):
+ def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.THREEFRY, seed)
+
  class Add(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)

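Editor's note: a hedged sketch of exercising the new mlops through the generic Function.apply mechanism; the user-facing Tensor wrappers added in tensor.py are not shown in this diff, so they are not used here.

from tinygrad import Tensor, dtypes
from tinygrad.function import BitwiseAnd

a = Tensor([0b1100, 0b1010], dtype=dtypes.int32)
b = Tensor([0b1010, 0b0110], dtype=dtypes.int32)
c = BitwiseAnd.apply(a, b)   # forward -> LazyBuffer.e(BinaryOps.AND, ...)
print(c.tolist())            # [8, 2]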
tinygrad/helpers.py CHANGED
@@ -1,6 +1,7 @@
  from __future__ import annotations
  import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys
- import itertools, urllib.request, subprocess, shutil, math, json
+ import itertools, urllib.request, subprocess, shutil, math, json, contextvars
+ from dataclasses import dataclass
  from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
  if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
  from typing_extensions import TypeGuard
@@ -22,9 +23,11 @@ def argfix(*x):
  return tuple(x[0])
  return x
  def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
- def all_same(items:List[T]): return all(x == items[0] for x in items)
+ def all_same(items:Union[Tuple[T, ...], List[T]]): return all(x == items[0] for x in items)
  def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
  def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
+ def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
+ def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
  def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
  def ansilen(s:str): return len(ansistrip(s))
  def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
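Editor's note: a quick sanity check of the two new formatting helpers, derived directly from their definitions above.

from tinygrad.helpers import memsize_to_str, colorize_float

print(memsize_to_str(2_300_000))  # "2.30 MB"
print(memsize_to_str(512))        # "512.00 B"
print(colorize_float(0.62))       # "   0.62x", colored green (below 0.75x)
print(colorize_float(1.40))       # "   1.40x", colored red (above 1.15x)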
@@ -33,13 +36,15 @@ def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(s
  def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
  def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
  def round_up(num, amt:int): return (num+amt-1)//amt * amt
+ def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
+ def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
  def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
  assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
  return {k:v for d in ds for k,v in d.items()}
- def partition(lst:List[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
+ def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
  a:List[T] = []
  b:List[T] = []
- for s in lst: (a if fxn(s) else b).append(s)
+ for s in itr: (a if fxn(s) else b).append(s)
  return a,b
  def unwrap(x:Optional[T]) -> T:
  assert x is not None
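Editor's note: the data64 helpers split a 64-bit value into 32-bit words ((hi, lo) for data64, (lo, hi) for data64_le), and partition now accepts any iterable. A small check that follows directly from the definitions above.

from tinygrad.helpers import data64, data64_le, partition

assert data64(0x1_0000_0002) == (1, 2)
assert data64_le(0x1_0000_0002) == (2, 1)
evens, odds = partition(range(6), lambda x: x % 2 == 0)
assert (evens, odds) == ([0, 2, 4], [1, 3, 5])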
@@ -74,8 +79,6 @@ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+str
  def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
  def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()

- class GraphException(Exception): pass
-
  class Context(contextlib.ContextDecorator):
  stack: ClassVar[List[dict[str, int]]] = [{}]
  def __init__(self, **kwargs): self.kwargs = kwargs
@@ -101,9 +104,22 @@ class ContextVar:
  def __lt__(self, x): return self.value < x

  DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
- WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1)
+ WINO, THREEFRY, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
  GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
- MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)
+ MULTIOUTPUT, PROFILE, PROFILEPATH = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
+ USE_TC, TC_OPT, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("TRANSCENDENTAL", 1)
+ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
+ SPLIT_REDUCEOP, ARANGE_DIFF = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("ARANGE_DIFF", 0)
+
+ @dataclass(frozen=True)
+ class Metadata:
+ name: str
+ caller: str
+ backward: bool = False
+ def __hash__(self): return hash(self.name)
+ def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
+ def __str__(self): return self.name + (" bw" if self.backward else "")
+ _METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)

  # **************** global state Counters ****************

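Editor's note: a tiny illustration of the new Metadata record; the caller string below is made up, and __str__/__repr__ behave exactly as defined above.

from tinygrad.helpers import Metadata

m = Metadata(name="__add__", caller="model.py:42", backward=True)
print(str(m))   # "__add__ bw"
print(repr(m))  # "__add__ bw - model.py:42"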
@@ -147,30 +163,40 @@ class Profiling(contextlib.ContextDecorator):
  class ProfileLogger:
  writers: int = 0
  mjson: List[Dict] = []
- actors: Dict[str, int] = {}
- subactors: Dict[Tuple[str, str], int] = {}
- path = getenv("PROFILE_OUTPUT_FILE", temp("tinygrad_profile.json"))
+ actors: Dict[Union[str, Tuple[str, str]], int] = {}

- def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
+ def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1

- def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
+ def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]

- def __del__(self):
- for name,st,et,actor_name,subactor_name in self.events:
- if actor_name not in self.actors:
- self.actors[actor_name] = (pid:=len(self.actors))
- self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
+ def _ensure_actor(self, actor_name, subactor_name):
+ if actor_name not in self.actors:
+ self.actors[actor_name] = (pid:=len(self.actors))
+ self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})

- if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
- self.subactors[subactor_key] = (tid:=len(self.subactors))
- self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
+ if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
+ self.actors[subactor_key] = (tid:=len(self.actors))
+ self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})

- self.mjson.append({"name": name, "ph": "X", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts":st, "dur":et-st})
+ return self.actors[actor_name], self.actors.get(subactor_key, -1)
+
+ def __del__(self):
+ # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
+ for name, st, et, actor_name, subactor_name, args in self.events:
+ pid, tid = self._ensure_actor(actor_name,subactor_name)
+ args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
+ self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
+
+ for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
+ dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
+ pid, tid = self._ensure_actor(actor_name,subactor_name)
+ self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
+ self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})

  ProfileLogger.writers -= 1
- if ProfileLogger.writers == 0:
- with open(self.path, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
- print(f"Saved profile to {self.path}. Use https://ui.perfetto.dev/ to open it.")
+ if ProfileLogger.writers == 0 and len(self.mjson) > 0:
+ with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
+ print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")

  # *** universal database cache ***

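Editor's note: a hedged usage sketch of the reworked ProfileLogger. The actor/event names are made up; timestamps and durations are microseconds in the Trace Event format, and args values may be strings or callables of the duration, as handled in __del__ above.

from tinygrad.helpers import ProfileLogger

pl = ProfileLogger()
pl.add_event("E_4_16n2", 100, 250, actor="NV", subactor="compute", args={"kernel": "E_4_16n2"})
del pl  # the last writer's __del__ dumps {"traceEvents": ...} to PROFILEPATH and prints the perfetto hint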
@@ -184,7 +210,10 @@ def db_connection():
  global _db_connection
  if _db_connection is None:
  os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True)
- _db_connection = sqlite3.connect(CACHEDB)
+ _db_connection = sqlite3.connect(CACHEDB, timeout=60, isolation_level="IMMEDIATE")
+ # another connection has set it already or is in the process of setting it
+ # that connection will lock the database
+ with contextlib.suppress(sqlite3.OperationalError): _db_connection.execute("PRAGMA journal_mode=WAL").fetchone()
  if DEBUG >= 7: _db_connection.set_trace_callback(print)
  return _db_connection

@@ -239,7 +268,7 @@ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional
  with urllib.request.urlopen(url, timeout=10) as r:
  assert r.status == 200
  total_length = int(r.headers.get('content-length', 0))
- progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ", disable=CI)
+ progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
  (path := fp.parent).mkdir(parents=True, exist_ok=True)
  with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
  while chunk := r.read(16384): progress_bar.update(f.write(chunk))
@@ -277,34 +306,43 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
  def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
  def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))

+ # *** tqdm
+
  class tqdm:
- def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:int=-1, rate:int=100):
+ def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
  self.iter, self.desc, self.dis, self.unit, self.unit_scale, self.rate = iterable, f"{desc}: " if desc else "", disable, unit, unit_scale, rate
- self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, len(iterable) if total==-1 else total
+ self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
  self.update(0)
  def __iter__(self):
- try:
- for item in self.iter:
- yield item
- self.update(1)
- finally: self.update(close=True)
+ for item in self.iter:
+ yield item
+ self.update(1)
+ self.update(close=True)
  def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
  def update(self, n:int=0, close:bool=False):
  self.n, self.i = self.n+n, self.i+1
- if (self.i % self.skip != 0 and not close) or self.dis: return
- prog, dur, term = self.n/self.t if self.t else -1, time.perf_counter()-self.st, shutil.get_terminal_size().columns
- if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1) if self.i else 1
- def fmt(t): return ':'.join([f'{x:02d}' for x in divmod(int(t), 60)]) if t!=-1 else '?'
- def scl(x): return x/1000**int(math.log(x,1000))
- def fn(x): return (f"{scl(x):.{3-math.ceil(math.log10(scl(x)))}f}"[:4]+(f"{[' ','k','M','G','T','P'][int(math.log(x,1000))]}") if x else '0.00')
- if self.t: unit_text = f"{fn(self.n)}/{fn(self.t)}" if self.unit_scale else f"{self.n}/{self.t}"
- else: unit_text = f"{fn(self.n)}{self.unit}" if self.unit_scale else f"{self.n}{self.unit}"
- it_text = f"{fn(self.n/dur)}" if self.n and self.unit_scale else f"{self.n/dur:5.2f}" if self.n else "?"
- if self.t: suf = f'| {unit_text} [{fmt(dur)}<{fmt(dur/self.n*self.t-dur if self.n else -1)}, {it_text}{self.unit}/s]'
- else: suf = f'{unit_text} [{fmt(dur)}, {it_text}{self.unit}/s]'
- sz = max(term-5-len(suf)-len(self.desc), 1)
- bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
- print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
+ if self.dis or (not close and self.i % self.skip != 0): return
+ prog, dur, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns
+ if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1)
+ def fmt(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
+ def fn(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
+ unit_text = f'{fn(self.n)}{f"/{fn(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
+ it_text = (fn(self.n/dur) if self.unit_scale else f"{self.n/dur:5.2f}") if self.n else "?"
+ tm = f'{fmt(dur)}<{fmt(dur/prog-dur) if self.n else "?"}' if self.t else fmt(dur)
+ suf = f'{unit_text} [{tm}, {it_text}{self.unit}/s]'
+ sz = max(ncols-len(self.desc)-5-2-len(suf), 1)
+ bar = '\r' + self.desc + (f'{100*prog:3.0f}%|{("█"*int(num:=sz*prog)+" ▏▎▍▌▋▊▉"[int(8*num)%8].strip()).ljust(sz," ")}| ' if self.t else '') + suf
+ print(bar[:ncols+1],flush=True,end='\n'*close,file=sys.stderr)

  class trange(tqdm):
- def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
+ def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
+
+ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
+ def dfs(x:Any, cache:dict):
+ for s in srcfn(x) or []:
+ cache.setdefault(s, [len(cache), 0, False])[1] += 1
+ if cache[s][1] == 1: dfs(s, cache)
+ if cache is None: dfs(x, cache:={})
+ if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
+ cx[2], srcs = True, ('None' if srcfn(x) is None else''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
+ return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
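Editor's note: the rewritten tqdm/trange keep the familiar call shape (for x in tqdm(range(100), desc="step"): ...). pretty_print is a generic DAG printer: rep must return a %-style template with a single %s where the indented sources are spliced in, and nodes reached more than once are printed once and referenced as x0, x1, ... A self-contained sketch on a hypothetical Node type (not a tinygrad class):

from tinygrad.helpers import pretty_print

class Node:
  def __init__(self, name, src=()): self.name, self.src = name, src

a = Node("a")
b = Node("b", (a,))
c = Node("c", (a, b))
print(pretty_print(c, rep=lambda n: f"Node({n.name!r}, src=(%s))", srcfn=lambda n: n.src))
# prints roughly:
# Node('c', src=(
#   x0:=Node('a', src=()),
#   Node('b', src=(
#      x0,)),))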
tinygrad/lazy.py CHANGED
@@ -1,44 +1,44 @@
  from __future__ import annotations
- import math
- from typing import Union, Optional, Any, Tuple, List
- from tinygrad.dtype import dtypes, DType, ConstType
- from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
- from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
+ from typing import Union, Optional, Any, Tuple, List, get_args
+ from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
+ from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
+ from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
  from tinygrad.shape.symbolic import sint, Variable
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.device import Buffer
  from weakref import ref, ReferenceType, WeakValueDictionary

  lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
- def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
  base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
- if st.size == 0: op, arg, srcs, base = LoadOps.CONST, 0, (), None
- if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
+ if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
+ dtype = to_dtype(dtype)
+ if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True

  cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
  if enable_cache and (rret := lazycache.get(cache_key, None)): return rret

- ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
+ ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
  if enable_cache: lazycache[cache_key] = ret
  return ret

- view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "DISK"}
+ view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "DISK"}
  class LazyBuffer:
- def __init__(self, device:str, st:ShapeTracker, dtype:DType,
+ def __init__(self, device:str, st:ShapeTracker, dtype:DTypeLike,
  op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
- base:Optional[LazyBuffer]=None):
- self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
+ base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
+ self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
  self._base: Optional[LazyBuffer] = None
  if base is None:
  # properties on base
  self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
- assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
+ assert self.op is not MetaOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"

- if self.op is LoadOps.VIEW:
+ if self.op is MetaOps.VIEW:
  # some LazyBuffers can be processed with only a view, no AST required
- self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
+ self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
  else:
- self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
+ self.buffer = srcs[1].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype)
  self.buffer.ref(1)
  self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
  self.forced_realize = False
@@ -67,36 +67,36 @@ class LazyBuffer:
  def lbs(self) -> List[LazyBuffer]: return [self]

  @staticmethod
- def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
+ def metaop(op, shape:Tuple[sint,...], dtype:DTypeLike, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
  assert isinstance(src, tuple)
  return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)

  def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
- assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType"
+ assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
  shape = self.shape if shape is None else shape
- return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
+ return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)

  def is_realized(self) -> bool: return self.base.realized is not None

  def assign(self, x:LazyBuffer) -> LazyBuffer:
  assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
- return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
+ return LazyBuffer.metaop(MetaOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))

  def can_view(self): return self.st.consecutive and not self.is_unrealized_const() and self.device.split(":")[0] in view_supported_devices

  def contiguous(self, allow_buffer_view=True):
  if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
- ret = self.e(LoadOps.VIEW) if allow_buffer_view and self.can_view() else self.e(LoadOps.CONTIGUOUS)
+ ret = self.e(MetaOps.VIEW) if allow_buffer_view and self.can_view() else self.e(MetaOps.CONTIGUOUS)
  if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
  return ret
  self.base.forced_realize = True
  return self

- def cast(self, dtype:DType, bitcast:bool=False):
+ def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
  if self.dtype == dtype: return self
  if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
  if self.is_unrealized_unmasked_const() and not bitcast:
- return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
+ return create_lazybuffer(self.device, self.st, dtype, MetaOps.CONST, dtypes.as_const(self.base.arg, dtype))
  new_shape = self.shape
  if bitcast and self.dtype.itemsize != dtype.itemsize:
  if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
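Editor's note: a hedged migration sketch for downstream code, since LoadOps and LazyBuffer.loadop from 0.9.1 are now MetaOps and LazyBuffer.metaop (the device string and dtype below are chosen arbitrarily).

from tinygrad.ops import MetaOps        # was LoadOps in 0.9.1
from tinygrad.dtype import dtypes
from tinygrad.lazy import LazyBuffer

# 0.9.1: LazyBuffer.loadop(LoadOps.CONST, (), dtypes.float32, "CLANG", arg=1.0)
lb = LazyBuffer.metaop(MetaOps.CONST, (), dtypes.float32, "CLANG", arg=1.0)
print(lb.is_unrealized_const())         # True: nothing is scheduled or realized yet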
@@ -107,26 +107,26 @@ class LazyBuffer:
  elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
  # TODO: applying this makes gpt2 slower
  return self.base.cast(dtype, bitcast)._view(self.st)
- cast_op: Union[LoadOps, UnaryOps] = (LoadOps.VIEW if self.can_view() else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
+ cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
  return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))

- def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
+ def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, Variable)
  def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)

  def _copy(self, device:str) -> LazyBuffer:
- return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
+ return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, MetaOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)

  def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
  # no COPY
  if self.device == device: return self

  # double COPY = one COPY
- if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
+ if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is MetaOps.COPY:
  return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)

  # const doesn't have to be copied (issues with disk tensor)
  if self.is_unrealized_const():
- return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
+ return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)

  # if it's a shrink, do the shrink before the copy with CONTIGUOUS
  if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
@@ -134,7 +134,7 @@ class LazyBuffer:
  # copy the base and apply the shapetracker on the new device
  return self.base._copy(device)._view(self.st)

- def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
+ def e(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
  srcs: List[LazyBuffer] = []
  for s in (self,)+in_srcs:
  if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
@@ -171,13 +171,12 @@ class LazyBuffer:
  assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
  axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
  if len(axis) == 0: return self
- new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
- return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
+ return create_lazybuffer(self.device, ShapeTracker.from_shape(reduce_st(self.st, axis)), self.dtype, op, axis, (self,))

  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
- new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+ new_shape = reduce_st(self.st, axis)
  # TODO: this logic should move to the scheduler
- if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
+ if 0 in self.shape and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)

  # const folding
  # TODO: fold this for symbolic?
@@ -185,7 +184,7 @@
  return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)

  # TODO: can we split symbolic shape if the reduce axis is not symbolic?
- if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
+ if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
  prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
  return self._reduce_op(op, axis)