tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/search.py
CHANGED
@@ -6,9 +6,9 @@ from tinygrad.device import Device, Buffer, Compiler
|
|
6
6
|
from tinygrad.ops import MemBuffer
|
7
7
|
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
|
8
8
|
from tinygrad.dtype import ImageDType
|
9
|
-
from tinygrad.codegen.
|
9
|
+
from tinygrad.codegen.kernel import Kernel
|
10
10
|
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
|
11
|
-
from tinygrad.codegen.
|
11
|
+
from tinygrad.codegen.uopgraph import UOpGraph
|
12
12
|
from tinygrad.tensor import Tensor
|
13
13
|
from tinygrad.shape.symbolic import sym_infer
|
14
14
|
from tinygrad.engine.realize import CompiledRunner
|
@@ -22,6 +22,7 @@ actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for a
|
|
22
22
|
if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
|
23
23
|
actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
|
24
24
|
actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
|
25
|
+
actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
|
25
26
|
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
|
26
27
|
|
27
28
|
def _get_test_global_size(global_size, max_global_size, var_vals):
|
@@ -42,25 +43,27 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
|
|
42
43
|
try: car = CompiledRunner(p, precompiled=lib)
|
43
44
|
except AssertionError: return [math.inf] * cnt
|
44
45
|
tms = []
|
45
|
-
input_bufs = [rawbufs[i] for i
|
46
|
+
input_bufs = [rawbufs[i] for i in car.p.globals]
|
46
47
|
for _ in range(cnt):
|
47
48
|
if clear_l2:
|
48
|
-
|
49
|
+
if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches()
|
50
|
+
else:
|
51
|
+
with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
|
49
52
|
tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
|
50
|
-
if early_stop is not None and early_stop < tms
|
53
|
+
if early_stop is not None and early_stop < min(tms): break
|
51
54
|
return tms
|
52
55
|
|
53
56
|
class TimeoutException(Exception): pass
|
54
57
|
def timeout_handler(signum, frame): raise TimeoutException()
|
55
58
|
|
56
|
-
def _try_compile_linearized_w_idx(x:Tuple[int,
|
59
|
+
def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
|
57
60
|
signal.signal(signal.SIGALRM, timeout_handler)
|
58
61
|
# set timeout
|
59
62
|
signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
|
60
63
|
try:
|
61
|
-
x[1].
|
62
|
-
|
63
|
-
p
|
64
|
+
p = x[1].to_program(name_override="test")
|
65
|
+
assert p.uops is not None, "uop list wasn't generated?"
|
66
|
+
if len(p.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
|
64
67
|
st = time.perf_counter()
|
65
68
|
prog = compiler.compile(p.src)
|
66
69
|
et = time.perf_counter() - st
|
@@ -85,7 +88,7 @@ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_
|
|
85
88
|
# *** external API ***
|
86
89
|
|
87
90
|
# get (scrap) buffers for timing the linearizer
|
88
|
-
def bufs_from_lin(lin:
|
91
|
+
def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
|
89
92
|
bufsts:DefaultDict[int, List[MemBuffer]] = defaultdict(list)
|
90
93
|
for x in lin.membufs: bufsts[x.idx].append(x)
|
91
94
|
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
|
@@ -97,7 +100,7 @@ def bufs_from_lin(lin:Linearizer, allocate:bool=True) -> List[Buffer]:
|
|
97
100
|
return cast(List[Buffer], rawbufs)
|
98
101
|
|
99
102
|
# get dictionary of all possible actions
|
100
|
-
def
|
103
|
+
def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
|
101
104
|
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
|
102
105
|
for i,a in enumerate(actions):
|
103
106
|
if a.axis is not None and a.op is not OptOps.TC:
|
@@ -115,15 +118,15 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
|
|
115
118
|
return acted_lins
|
116
119
|
|
117
120
|
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
|
118
|
-
def beam_search(lin:
|
121
|
+
def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel:
|
119
122
|
global beam_pool
|
120
|
-
key = {"ast": lin.ast
|
121
|
-
if not
|
123
|
+
key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
|
124
|
+
if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
|
122
125
|
ret = lin.copy()
|
123
126
|
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
|
124
127
|
return ret
|
125
128
|
|
126
|
-
beam: List[Tuple[
|
129
|
+
beam: List[Tuple[Kernel, float]] = [(lin, float("inf"))]
|
127
130
|
seen_libs = set()
|
128
131
|
|
129
132
|
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
|
@@ -136,20 +139,24 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
|
|
136
139
|
|
137
140
|
try:
|
138
141
|
rawbufs = _ensure_buffer_alloc(rawbufs)
|
139
|
-
var_vals = {k:(k.max+k.min)//2 for k in lin.ast
|
142
|
+
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
|
140
143
|
exiting, st = False, time.perf_counter()
|
141
144
|
dev = Device[lin.opts.device]
|
142
145
|
while not exiting:
|
143
|
-
acted_lins: List[
|
144
|
-
timed_lins: List[Tuple[
|
146
|
+
acted_lins: List[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
|
147
|
+
timed_lins: List[Tuple[Kernel, float]] = []
|
145
148
|
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
|
149
|
+
least_compute_ops = math.inf
|
146
150
|
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
|
147
151
|
if proc is None: continue
|
148
152
|
p, lib, compile_et = proc
|
149
153
|
if lib in seen_libs: continue
|
154
|
+
# filter out kernels that use 1000x more compute than the smallest
|
155
|
+
least_compute_ops = min(this_compute_ops:=sym_infer(p.op_estimate, var_vals), least_compute_ops)
|
156
|
+
if least_compute_ops*1000 < this_compute_ops: continue
|
150
157
|
#print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
|
151
158
|
seen_libs.add(lib)
|
152
|
-
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
|
159
|
+
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
|
153
160
|
except RuntimeError: continue # for runtime issues
|
154
161
|
timed_lins.append((acted_lins[i], min(tms)))
|
155
162
|
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(UOpGraph, p.uops).uops):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
@@ -181,8 +188,8 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
|
|
181
188
|
assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
|
182
189
|
return ret[1]
|
183
190
|
|
184
|
-
def time_linearizer(lin:
|
185
|
-
key = {"ast": lin.ast
|
191
|
+
def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
|
192
|
+
key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
|
186
193
|
"max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
|
187
194
|
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
|
188
195
|
|
@@ -190,7 +197,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
|
|
190
197
|
assert dev.compiler is not None
|
191
198
|
|
192
199
|
rawbufs = _ensure_buffer_alloc(rawbufs)
|
193
|
-
var_vals = {k:(k.max+k.min)//2 for k in lin.ast
|
200
|
+
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
|
194
201
|
p = lin.to_program()
|
195
202
|
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
|
196
203
|
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
|
tinygrad/function.py
CHANGED
@@ -106,6 +106,15 @@ class Neq(Function):
|
|
106
106
|
class Xor(Function):
|
107
107
|
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)
|
108
108
|
|
109
|
+
class BitwiseAnd(Function):
|
110
|
+
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.AND, y)
|
111
|
+
|
112
|
+
class BitwiseOr(Function):
|
113
|
+
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.OR, y)
|
114
|
+
|
115
|
+
class Threefry(Function):
|
116
|
+
def forward(self, x:LazyBuffer, seed:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.THREEFRY, seed)
|
117
|
+
|
109
118
|
class Add(Function):
|
110
119
|
def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)
|
111
120
|
|
tinygrad/helpers.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys
|
3
|
-
import itertools, urllib.request, subprocess, shutil, math, json
|
3
|
+
import itertools, urllib.request, subprocess, shutil, math, json, contextvars
|
4
|
+
from dataclasses import dataclass
|
4
5
|
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
|
5
6
|
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
6
7
|
from typing_extensions import TypeGuard
|
@@ -22,9 +23,11 @@ def argfix(*x):
|
|
22
23
|
return tuple(x[0])
|
23
24
|
return x
|
24
25
|
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
|
25
|
-
def all_same(items:List[T]): return all(x == items[0] for x in items)
|
26
|
+
def all_same(items:Union[Tuple[T, ...], List[T]]): return all(x == items[0] for x in items)
|
26
27
|
def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
|
27
28
|
def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
|
29
|
+
def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
|
30
|
+
def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
|
28
31
|
def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
|
29
32
|
def ansilen(s:str): return len(ansistrip(s))
|
30
33
|
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
|
@@ -33,13 +36,15 @@ def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(s
|
|
33
36
|
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
|
34
37
|
def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
|
35
38
|
def round_up(num, amt:int): return (num+amt-1)//amt * amt
|
39
|
+
def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
|
40
|
+
def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
|
36
41
|
def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
|
37
42
|
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
|
38
43
|
return {k:v for d in ds for k,v in d.items()}
|
39
|
-
def partition(
|
44
|
+
def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
|
40
45
|
a:List[T] = []
|
41
46
|
b:List[T] = []
|
42
|
-
for s in
|
47
|
+
for s in itr: (a if fxn(s) else b).append(s)
|
43
48
|
return a,b
|
44
49
|
def unwrap(x:Optional[T]) -> T:
|
45
50
|
assert x is not None
|
@@ -74,8 +79,6 @@ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+str
|
|
74
79
|
def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
|
75
80
|
def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
|
76
81
|
|
77
|
-
class GraphException(Exception): pass
|
78
|
-
|
79
82
|
class Context(contextlib.ContextDecorator):
|
80
83
|
stack: ClassVar[List[dict[str, int]]] = [{}]
|
81
84
|
def __init__(self, **kwargs): self.kwargs = kwargs
|
@@ -101,9 +104,22 @@ class ContextVar:
|
|
101
104
|
def __lt__(self, x): return self.value < x
|
102
105
|
|
103
106
|
DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
|
104
|
-
WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1)
|
107
|
+
WINO, THREEFRY, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
|
105
108
|
GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
|
106
|
-
MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)
|
109
|
+
MULTIOUTPUT, PROFILE, PROFILEPATH = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
|
110
|
+
USE_TC, TC_OPT, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("TRANSCENDENTAL", 1)
|
111
|
+
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
|
112
|
+
SPLIT_REDUCEOP, ARANGE_DIFF = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("ARANGE_DIFF", 0)
|
113
|
+
|
114
|
+
@dataclass(frozen=True)
|
115
|
+
class Metadata:
|
116
|
+
name: str
|
117
|
+
caller: str
|
118
|
+
backward: bool = False
|
119
|
+
def __hash__(self): return hash(self.name)
|
120
|
+
def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
|
121
|
+
def __str__(self): return self.name + (" bw" if self.backward else "")
|
122
|
+
_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
|
107
123
|
|
108
124
|
# **************** global state Counters ****************
|
109
125
|
|
@@ -147,30 +163,40 @@ class Profiling(contextlib.ContextDecorator):
|
|
147
163
|
class ProfileLogger:
|
148
164
|
writers: int = 0
|
149
165
|
mjson: List[Dict] = []
|
150
|
-
actors: Dict[str, int] = {}
|
151
|
-
subactors: Dict[Tuple[str, str], int] = {}
|
152
|
-
path = getenv("PROFILE_OUTPUT_FILE", temp("tinygrad_profile.json"))
|
166
|
+
actors: Dict[Union[str, Tuple[str, str]], int] = {}
|
153
167
|
|
154
|
-
def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
|
168
|
+
def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
|
155
169
|
|
156
|
-
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
|
170
|
+
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
|
157
171
|
|
158
|
-
def
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
|
172
|
+
def _ensure_actor(self, actor_name, subactor_name):
|
173
|
+
if actor_name not in self.actors:
|
174
|
+
self.actors[actor_name] = (pid:=len(self.actors))
|
175
|
+
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
|
163
176
|
|
164
|
-
|
165
|
-
|
166
|
-
|
177
|
+
if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
|
178
|
+
self.actors[subactor_key] = (tid:=len(self.actors))
|
179
|
+
self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
|
167
180
|
|
168
|
-
|
181
|
+
return self.actors[actor_name], self.actors.get(subactor_key, -1)
|
182
|
+
|
183
|
+
def __del__(self):
|
184
|
+
# perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
|
185
|
+
for name, st, et, actor_name, subactor_name, args in self.events:
|
186
|
+
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
187
|
+
args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
|
188
|
+
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
|
189
|
+
|
190
|
+
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
|
191
|
+
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
|
192
|
+
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
193
|
+
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
|
194
|
+
self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
|
169
195
|
|
170
196
|
ProfileLogger.writers -= 1
|
171
|
-
if ProfileLogger.writers == 0:
|
172
|
-
with open(
|
173
|
-
print(f"Saved profile to {
|
197
|
+
if ProfileLogger.writers == 0 and len(self.mjson) > 0:
|
198
|
+
with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
|
199
|
+
print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
|
174
200
|
|
175
201
|
# *** universal database cache ***
|
176
202
|
|
@@ -184,7 +210,10 @@ def db_connection():
|
|
184
210
|
global _db_connection
|
185
211
|
if _db_connection is None:
|
186
212
|
os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True)
|
187
|
-
_db_connection = sqlite3.connect(CACHEDB)
|
213
|
+
_db_connection = sqlite3.connect(CACHEDB, timeout=60, isolation_level="IMMEDIATE")
|
214
|
+
# another connection has set it already or is in the process of setting it
|
215
|
+
# that connection will lock the database
|
216
|
+
with contextlib.suppress(sqlite3.OperationalError): _db_connection.execute("PRAGMA journal_mode=WAL").fetchone()
|
188
217
|
if DEBUG >= 7: _db_connection.set_trace_callback(print)
|
189
218
|
return _db_connection
|
190
219
|
|
@@ -239,7 +268,7 @@ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional
|
|
239
268
|
with urllib.request.urlopen(url, timeout=10) as r:
|
240
269
|
assert r.status == 200
|
241
270
|
total_length = int(r.headers.get('content-length', 0))
|
242
|
-
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}
|
271
|
+
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
|
243
272
|
(path := fp.parent).mkdir(parents=True, exist_ok=True)
|
244
273
|
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
|
245
274
|
while chunk := r.read(16384): progress_bar.update(f.write(chunk))
|
@@ -277,34 +306,43 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
|
|
277
306
|
def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
|
278
307
|
def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
|
279
308
|
|
309
|
+
# *** tqdm
|
310
|
+
|
280
311
|
class tqdm:
|
281
|
-
def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:int
|
312
|
+
def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
|
282
313
|
self.iter, self.desc, self.dis, self.unit, self.unit_scale, self.rate = iterable, f"{desc}: " if desc else "", disable, unit, unit_scale, rate
|
283
|
-
self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1,
|
314
|
+
self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
|
284
315
|
self.update(0)
|
285
316
|
def __iter__(self):
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
finally: self.update(close=True)
|
317
|
+
for item in self.iter:
|
318
|
+
yield item
|
319
|
+
self.update(1)
|
320
|
+
self.update(close=True)
|
291
321
|
def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
|
292
322
|
def update(self, n:int=0, close:bool=False):
|
293
323
|
self.n, self.i = self.n+n, self.i+1
|
294
|
-
if (self.i % self.skip != 0
|
295
|
-
prog, dur,
|
296
|
-
if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1)
|
297
|
-
def fmt(t): return ':'.join(
|
298
|
-
def
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
|
307
|
-
print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
|
324
|
+
if self.dis or (not close and self.i % self.skip != 0): return
|
325
|
+
prog, dur, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns
|
326
|
+
if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1)
|
327
|
+
def fmt(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
|
328
|
+
def fn(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
|
329
|
+
unit_text = f'{fn(self.n)}{f"/{fn(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
|
330
|
+
it_text = (fn(self.n/dur) if self.unit_scale else f"{self.n/dur:5.2f}") if self.n else "?"
|
331
|
+
tm = f'{fmt(dur)}<{fmt(dur/prog-dur) if self.n else "?"}' if self.t else fmt(dur)
|
332
|
+
suf = f'{unit_text} [{tm}, {it_text}{self.unit}/s]'
|
333
|
+
sz = max(ncols-len(self.desc)-5-2-len(suf), 1)
|
334
|
+
bar = '\r' + self.desc + (f'{100*prog:3.0f}%|{("█"*int(num:=sz*prog)+" ▏▎▍▌▋▊▉"[int(8*num)%8].strip()).ljust(sz," ")}| ' if self.t else '') + suf
|
335
|
+
print(bar[:ncols+1],flush=True,end='\n'*close,file=sys.stderr)
|
308
336
|
|
309
337
|
class trange(tqdm):
|
310
|
-
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
|
338
|
+
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
|
339
|
+
|
340
|
+
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
|
341
|
+
def dfs(x:Any, cache:dict):
|
342
|
+
for s in srcfn(x) or []:
|
343
|
+
cache.setdefault(s, [len(cache), 0, False])[1] += 1
|
344
|
+
if cache[s][1] == 1: dfs(s, cache)
|
345
|
+
if cache is None: dfs(x, cache:={})
|
346
|
+
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
|
347
|
+
cx[2], srcs = True, ('None' if srcfn(x) is None else''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
|
348
|
+
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
|
tinygrad/lazy.py
CHANGED
@@ -1,44 +1,44 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
import
|
3
|
-
from
|
4
|
-
from tinygrad.
|
5
|
-
from tinygrad.
|
6
|
-
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
|
2
|
+
from typing import Union, Optional, Any, Tuple, List, get_args
|
3
|
+
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
|
4
|
+
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
|
5
|
+
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
|
7
6
|
from tinygrad.shape.symbolic import sint, Variable
|
8
7
|
from tinygrad.shape.shapetracker import ShapeTracker
|
9
8
|
from tinygrad.device import Buffer
|
10
9
|
from weakref import ref, ReferenceType, WeakValueDictionary
|
11
10
|
|
12
11
|
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
|
13
|
-
def create_lazybuffer(device:str, st:ShapeTracker, dtype:
|
12
|
+
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DTypeLike, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
14
13
|
base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
|
15
|
-
if st.size == 0: op, arg, srcs, base =
|
16
|
-
|
14
|
+
if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None
|
15
|
+
dtype = to_dtype(dtype)
|
16
|
+
if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
|
17
17
|
|
18
18
|
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
|
19
19
|
if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
|
20
20
|
|
21
|
-
ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
|
21
|
+
ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
|
22
22
|
if enable_cache: lazycache[cache_key] = ret
|
23
23
|
return ret
|
24
24
|
|
25
|
-
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "DISK"}
|
25
|
+
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "DISK"}
|
26
26
|
class LazyBuffer:
|
27
|
-
def __init__(self, device:str, st:ShapeTracker, dtype:
|
27
|
+
def __init__(self, device:str, st:ShapeTracker, dtype:DTypeLike,
|
28
28
|
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
29
|
-
base:Optional[LazyBuffer]=None):
|
30
|
-
self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
|
29
|
+
base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
|
30
|
+
self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
|
31
31
|
self._base: Optional[LazyBuffer] = None
|
32
32
|
if base is None:
|
33
33
|
# properties on base
|
34
34
|
self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
|
35
|
-
assert self.op is not
|
35
|
+
assert self.op is not MetaOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
|
36
36
|
|
37
|
-
if self.op is
|
37
|
+
if self.op is MetaOps.VIEW:
|
38
38
|
# some LazyBuffers can be processed with only a view, no AST required
|
39
|
-
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
|
39
|
+
self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
|
40
40
|
else:
|
41
|
-
self.buffer = srcs[1].base.buffer if self.op is
|
41
|
+
self.buffer = srcs[1].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype)
|
42
42
|
self.buffer.ref(1)
|
43
43
|
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
|
44
44
|
self.forced_realize = False
|
@@ -67,36 +67,36 @@ class LazyBuffer:
|
|
67
67
|
def lbs(self) -> List[LazyBuffer]: return [self]
|
68
68
|
|
69
69
|
@staticmethod
|
70
|
-
def
|
70
|
+
def metaop(op, shape:Tuple[sint,...], dtype:DTypeLike, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
|
71
71
|
assert isinstance(src, tuple)
|
72
72
|
return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
|
73
73
|
|
74
74
|
def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
|
75
|
-
assert isinstance(val, (
|
75
|
+
assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
|
76
76
|
shape = self.shape if shape is None else shape
|
77
|
-
return LazyBuffer.
|
77
|
+
return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
|
78
78
|
|
79
79
|
def is_realized(self) -> bool: return self.base.realized is not None
|
80
80
|
|
81
81
|
def assign(self, x:LazyBuffer) -> LazyBuffer:
|
82
82
|
assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
|
83
|
-
return LazyBuffer.
|
83
|
+
return LazyBuffer.metaop(MetaOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
|
84
84
|
|
85
85
|
def can_view(self): return self.st.consecutive and not self.is_unrealized_const() and self.device.split(":")[0] in view_supported_devices
|
86
86
|
|
87
87
|
def contiguous(self, allow_buffer_view=True):
|
88
88
|
if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
|
89
|
-
ret = self.e(
|
89
|
+
ret = self.e(MetaOps.VIEW) if allow_buffer_view and self.can_view() else self.e(MetaOps.CONTIGUOUS)
|
90
90
|
if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
|
91
91
|
return ret
|
92
92
|
self.base.forced_realize = True
|
93
93
|
return self
|
94
94
|
|
95
|
-
def cast(self, dtype:DType, bitcast:bool=False):
|
95
|
+
def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
|
96
96
|
if self.dtype == dtype: return self
|
97
97
|
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
|
98
98
|
if self.is_unrealized_unmasked_const() and not bitcast:
|
99
|
-
return create_lazybuffer(self.device, self.st, dtype,
|
99
|
+
return create_lazybuffer(self.device, self.st, dtype, MetaOps.CONST, dtypes.as_const(self.base.arg, dtype))
|
100
100
|
new_shape = self.shape
|
101
101
|
if bitcast and self.dtype.itemsize != dtype.itemsize:
|
102
102
|
if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
|
@@ -107,26 +107,26 @@ class LazyBuffer:
|
|
107
107
|
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
|
108
108
|
# TODO: applying this makes gpt2 slower
|
109
109
|
return self.base.cast(dtype, bitcast)._view(self.st)
|
110
|
-
cast_op: Union[
|
110
|
+
cast_op: Union[MetaOps, UnaryOps] = (MetaOps.VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
|
111
111
|
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
|
112
112
|
|
113
|
-
def is_unrealized_const(self): return self.base.realized is None and self.base.op is
|
113
|
+
def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, Variable)
|
114
114
|
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
|
115
115
|
|
116
116
|
def _copy(self, device:str) -> LazyBuffer:
|
117
|
-
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype,
|
117
|
+
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, MetaOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
|
118
118
|
|
119
119
|
def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
|
120
120
|
# no COPY
|
121
121
|
if self.device == device: return self
|
122
122
|
|
123
123
|
# double COPY = one COPY
|
124
|
-
if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is
|
124
|
+
if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is MetaOps.COPY:
|
125
125
|
return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)
|
126
126
|
|
127
127
|
# const doesn't have to be copied (issues with disk tensor)
|
128
128
|
if self.is_unrealized_const():
|
129
|
-
return LazyBuffer.
|
129
|
+
return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
|
130
130
|
|
131
131
|
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
|
132
132
|
if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
|
@@ -134,7 +134,7 @@ class LazyBuffer:
|
|
134
134
|
# copy the base and apply the shapetracker on the new device
|
135
135
|
return self.base._copy(device)._view(self.st)
|
136
136
|
|
137
|
-
def e(self, op:Union[
|
137
|
+
def e(self, op:Union[MetaOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
|
138
138
|
srcs: List[LazyBuffer] = []
|
139
139
|
for s in (self,)+in_srcs:
|
140
140
|
if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
|
@@ -171,13 +171,12 @@ class LazyBuffer:
|
|
171
171
|
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
|
172
172
|
axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
|
173
173
|
if len(axis) == 0: return self
|
174
|
-
|
175
|
-
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
|
174
|
+
return create_lazybuffer(self.device, ShapeTracker.from_shape(reduce_st(self.st, axis)), self.dtype, op, axis, (self,))
|
176
175
|
|
177
176
|
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
|
178
|
-
new_shape =
|
177
|
+
new_shape = reduce_st(self.st, axis)
|
179
178
|
# TODO: this logic should move to the scheduler
|
180
|
-
if self.
|
179
|
+
if 0 in self.shape and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: dtypes.min(self.dtype)}[op], new_shape)
|
181
180
|
|
182
181
|
# const folding
|
183
182
|
# TODO: fold this for symbolic?
|
@@ -185,7 +184,7 @@ class LazyBuffer:
|
|
185
184
|
return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
|
186
185
|
|
187
186
|
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
|
188
|
-
if not
|
187
|
+
if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
|
189
188
|
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
|
190
189
|
return self._reduce_op(op, axis)
|
191
190
|
|