tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
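The most visible API move in the list above is the deletion of tinygrad/shape/symbolic.py: symbolic variables (and the sint type) now come from tinygrad.ops, which the engine/jit.py import hunk below also reflects. A minimal migration sketch, assuming tinygrad 0.10.0 is installed; the variable name "batch" is only an illustration:

# 0.9.1: from tinygrad.shape.symbolic import Variable, sint   (module removed in 0.10.0)
from tinygrad.ops import Variable  # 0.10.0 location, matching the import hunk in engine/jit.py below

v = Variable("batch", 1, 16)  # a bounded symbolic dimension; "batch" is an illustrative name
print(v.expr)                 # the JIT sorts its vars by .expr, i.e. the variable's name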
tinygrad/engine/jit.py
CHANGED
@@ -1,22 +1,24 @@
 from __future__ import annotations
 from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
-import functools,
+import functools, collections
 from tinygrad.tensor import Tensor
-from tinygrad.lazy import LazyBuffer
-from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context,
+from tinygrad.engine.lazy import LazyBuffer
+from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
+from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
-from tinygrad.engine.
+from tinygrad.engine.memory import _internal_memory_planner
 from tinygrad.nn.state import get_parameters
+from dataclasses import dataclass
 from weakref import WeakKeyDictionary

-
+class GraphException(Exception): pass
+
+def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], max_batch_size=0) -> List[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.
-  max_batch_size = getenv("JIT_BATCH_SIZE", 32)
   graphed_jit_cache: List[ExecItem] = []
   current_batch: List[ExecItem] = []
   current_device: Optional[Compiled] = None
@@ -30,10 +32,10 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
       for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
       graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
       max_batch_size *= 2
-      if DEBUG >= 2: print(f"
+      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
     except GraphException as e:
       graphed_jit_cache.extend(current_batch)
-      if DEBUG >= 2: print(f"
+      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
     current_batch = []
     current_device = None

@@ -44,11 +46,11 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
     elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
       ji_graph_dev = Device[ji.bufs[0].device]

-    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None
+    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None
     can_be_graphed = ji_graph_dev and ji_graph_dev.graph
     can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
-                       type(ji_graph_dev)
-    can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph
+                       type(ji_graph_dev) is type(current_device))
+    can_extend_graph_batch = can_be_graphed and (max_batch_size == 0 or len(current_batch) < max_batch_size) and can_share_graph
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()

     if can_be_graphed: current_batch.append(ji)
@@ -70,129 +72,224 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
 class GraphRunner(Runner): # pylint: disable=abstract-method
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     self.jit_cache = jit_cache
-    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.
-    self.
+    self.input_replace:Dict[Tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
+    self.var_vals_replace:Dict[int, List[int]] = {}
+    self.launch_dims_replace:Dict[int, Tuple[Optional[int], Optional[int]]] = {}
+
     op_estimate: sint = 0
     mem_estimate: sint = 0
+    lds_estimate: sint = 0
+
+    def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
+
+    self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
+    self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
+                               [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
+    def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
+
     for j,ji in enumerate(jit_cache):
       op_estimate += ji.prg.op_estimate
       mem_estimate += ji.prg.mem_estimate
+      lds_estimate += ji.prg.lds_estimate
       if isinstance(ji.prg, CompiledRunner):
-        if ji.prg.p.vars: self.
-        if (ji.prg.p.global_size and not all_int(ji.prg.p.global_size)) or (ji.prg.p.local_size and not all_int(ji.prg.p.local_size)):
-          self.jc_idx_with_updatable_launch_dims.append(j)
-    self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
-    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)
+        if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]

-
-
-
-
-
+        global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
+        if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+
+    # used in MultiGraphRunner. the ints are id() of _bufs
+    self.w_dependency_map: Dict[int, Any] = {}
+    self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
+
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
+                     ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
+
+  def updated_vars(self, var_vals: Dict[Variable, int]):
+    vals = [var_vals[v] for v in self.vars]
+    for j, vidxs in self.var_vals_replace.items():
+      for i, v in enumerate(vidxs): yield j, i, vals[v]

-  def
+  def updated_launch_dims(self, var_vals: Dict[Variable, int]):
+    dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
+    for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)
+
+  def _access_resources(self, rawbufs:List[Buffer], write:List[int], new_dependency:Any):
     # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
     # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
     wait_nodes = []

-    for rawbuf in
+    for i,rawbuf in enumerate(rawbufs):
       if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
-
-
+      if i in write:
+        if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
+        self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
+      else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)

-    for rawbuf in read: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
-    for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
     return list({id(x):x for x in wait_nodes}.values())

+# a marker for your graph supporting multiple devices of the same type
+class MultiGraphRunner(GraphRunner): pass # pylint: disable=abstract-method
+
 ReturnType = TypeVar('ReturnType')
-
+@dataclass
+class CapturedJit(Generic[ReturnType]):
+  ret: Any # includes the Tensors or any other returned object
+  jit_cache: List[ExecItem]
+  input_replace: Dict[Tuple[int, int], int]
+  extra_view_inputs: List[Tuple[int, int, str, int, DType]]
+  expected_names: List[Union[int, str]]
+  expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]]
+
+  def __reduce__(self):
+    return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
+                            self.expected_names, self.expected_st_vars_dtype_device)
+
+  def __post_init__(self):
+    self._jit_cache: List[ExecItem] = self.jit_cache
+    self._input_replace: Dict[Tuple[int, int], int] = self.input_replace
+    self._graphed = False
+    self._clear_inputs()
+
+  def _clear_inputs(self):
+    for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
+
+  # jit exec
+  def __call__(self, input_buffers:List[Buffer], var_vals:Dict[Variable, int]) -> ReturnType:
+    # assign inputs
+    for idx, offset, device, size, dtype in self.extra_view_inputs:
+      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
+    for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
+
+    # Condense the items into a graph executor.
+    if JIT < 2 and not self._graphed:
+      self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
+      self._input_replace = get_input_replace(self._jit_cache, input_buffers)
+      self._graphed = True
+
+    if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
+    for ei in self._jit_cache: ei.run(var_vals, jit=True)
+    self._clear_inputs()
+    return self.ret
+
+def _prepare_jit_inputs(args, kwargs):
+  input_tensors: List[Tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
+  names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
+  if tensors: Tensor.realize(*tensors)
+  lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for t in tensors])
+  input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
+  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
+  st_varval_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
+  var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
+  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
+  return input_buffers, var_vals, names, st_vars_dtype_device
+
 class TinyJit(Generic[ReturnType]):
-  def __init__(self, fxn:Callable[..., ReturnType]):
+  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
+    assert fxn or captured, "need either a function or a CapturedJit"
     self.fxn = fxn
-    self.
+    self.captured: Optional[CapturedJit] = captured
+    self.cnt: int = 2 if self.fxn is None else 0
+    self.prune = prune

   def add_buffer(self, b:Buffer) -> Buffer:
-    if found:=self.
+    if found:=self._buffer_replace.get(b, None): return found
     if b.is_allocated() or b.lb_refcount > 0: return b
     if b._base is not None:
-      self.
+      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
     else:
-      self.
+      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
     return ret

   def add(self, ei:ExecItem):
-    self.
+    self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))

   def reset(self):
-    self.
-    self.
-    self.
-
-
+    assert self.fxn is not None, "can't reset without function"
+    self.cnt = 0
+    self.captured = None
+
+  def __reduce__(self):
+    assert self.captured is not None, "can't pickle an uncaptured JIT"
+    return self.__class__, (None, self.captured)
+
+  # keep legacy code working
+  @property
+  def jit_cache(self) -> List[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
+  @property
+  def input_replace(self) -> Dict[Tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods

   def __call__(self, *args, **kwargs) -> ReturnType:
-
-      [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
-    if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
-    names: List[Union[int, str]] = [name for name,_ in input_tensors]
-    lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
-    st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
-    input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
-    assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
-    var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
-      [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
-    st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
+    input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
     if not JIT or self.cnt == 0:
-      if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
       # jit ignore
-
-
-
+      assert self.fxn is not None
+      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
+        ret = self.fxn(*args, **kwargs)
+        if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
     elif self.cnt == 1:
       # jit capture
-      self.
-
-
+      assert self.fxn is not None
+      if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
+      self._jit_cache: List[ExecItem] = []
+      self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
+      # TODO: should we always disable the memory planner here? it must be off for prune
+      with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
         capturing.append(self)
-
-
-
-
-
-
+        try:
+          ret = self.fxn(*args, **kwargs)
+          if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
+        except Exception as e: raise e
+        finally: capturing.clear()
+      jit_cache = self._jit_cache
+      del self._buffer_replace, self._jit_cache
+      assert len(jit_cache), "didn't JIT anything!"
+      if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")

       # track inputs that are views of buffers
-
+      # TODO: eventually expected_buffers should live in ExecItem
+      extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
+      for item in jit_cache:
         for b in item.bufs:
           if b is not None and b._base is not None and b._base in input_buffers:
             input_buffers.append(b)
-
+            extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
+
+      # prune independent kernels (optional)
+      if self.prune:
+        depends = set(input_buffers)
+        for ei in jit_cache:
+          if any(b in depends for b in ei.bufs):
+            if isinstance(ei.prg, CompiledRunner):
+              depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
+        pruned, onetime = partition(jit_cache,
+                                    lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
+        if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
+        # run the onetime kernels here
+        for ei in onetime:
+          for b in ei.bufs: cast(Buffer, b).ensure_allocated()
+          ei.run(var_vals, jit=True)
+        jit_cache = pruned

       # memory planning (optional)
-
-
+      # Exclude buffers involved in transfer ops to preserve parallelism.
+      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
+      assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
+      jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]

-
-      if
+      input_replace = get_input_replace(jit_cache, input_buffers)
+      if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

-
-
+      # set this for next run
+      self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
     elif self.cnt >= 2:
       # jit exec
-      assert self.
-      assert self.
-
-
-
-      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_buffers[input_idx]
-      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
-      for ei in self.jit_cache: ei.run(var_vals, jit=True)
-
-      # clear jit inputs
-      for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
+      assert self.captured is not None
+      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
+      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
+        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
+      ret = self.captured(input_buffers, var_vals)

     self.cnt += 1
-    return
+    return ret
tinygrad/{lazy.py → engine/lazy.py}
RENAMED
@@ -1,44 +1,44 @@
 from __future__ import annotations
-import
-from
-from tinygrad.
-from tinygrad.
-from tinygrad.ops import
-from tinygrad.shape.symbolic import sint, Variable
+from typing import Optional, Any, Tuple, List, get_args
+from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
+from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
+from tinygrad.ops import exec_alu, python_alu
+from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
 from weakref import ref, ReferenceType, WeakValueDictionary

 lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
-def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[
-                      base:Optional[LazyBuffer]=None, enable_cache=bool(
-  if st.size == 0: op, arg, srcs, base =
-
+def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+                      base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
+  if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None
+  dtype = to_dtype(dtype)
+  if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True

   cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
-  if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
+  if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret

-  ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
+  ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
   if enable_cache: lazycache[cache_key] = ret
   return ret

-view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "DISK"}
-class LazyBuffer:
+view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
+class LazyBuffer(MathTrait):
   def __init__(self, device:str, st:ShapeTracker, dtype:DType,
-               op:Optional[
-               base:Optional[LazyBuffer]=None):
-    self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
+               op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+               base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
+    self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
     self._base: Optional[LazyBuffer] = None
     if base is None:
       # properties on base
-      self.op, self.arg, self.srcs = op, arg, srcs # this is a
-      assert self.op is not
+      self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps
+      assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"

-      if self.op is
+      if self.op is Ops.BUFFER_VIEW:
         # some LazyBuffers can be processed with only a view, no AST required
-        self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
+        self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
       else:
-        self.buffer = srcs[
+        self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype)
       self.buffer.ref(1)
       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
       self.forced_realize = False
@@ -51,7 +51,7 @@ class LazyBuffer:
     if hasattr(self, 'buffer'): self.buffer.ref(-1)

   def __repr__(self) -> str:
-    return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base
+    return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base is not self else (self.op, self.realized)}>"

   @property
   def realized(self) -> Optional[Buffer]:
@@ -67,36 +67,42 @@ class LazyBuffer:
   def lbs(self) -> List[LazyBuffer]: return [self]

   @staticmethod
-  def
+  def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
     assert isinstance(src, tuple)
     return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)

-  def
-
-
-    return LazyBuffer.
+  def const_like(self, b): return self.const_with_shape(b, self.shape)
+  def const_with_shape(self, val:ConstType, shape:Tuple[sint,...]) -> LazyBuffer:
+    assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
+    return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)

+  @property
   def is_realized(self) -> bool: return self.base.realized is not None

   def assign(self, x:LazyBuffer) -> LazyBuffer:
     assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
-
+    assert self.is_realized, f"assign target must be realized {self}"
+    return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,),
+                             src=(self.base, x), enable_cache=True)

-  def can_view(self):
+  def can_view(self):
+    return (self.st.consecutive and not self.is_unrealized_const() and not isinstance(self.dtype, ImageDType) and
+            self.device.split(":")[0] in view_supported_devices)

   def contiguous(self, allow_buffer_view=True):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
-      ret = self.
+      ret = self.alu(Ops.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(Ops.CONTIGUOUS)
       if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
       return ret
     self.base.forced_realize = True
     return self

-  def
+  def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True)
+  def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
     if self.dtype == dtype: return self
     if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
     if self.is_unrealized_unmasked_const() and not bitcast:
-      return create_lazybuffer(self.device, self.st, dtype,
+      return create_lazybuffer(self.device, self.st, dtype, Ops.CONST, dtypes.as_const(self.base.arg, dtype))
     new_shape = self.shape
     if bitcast and self.dtype.itemsize != dtype.itemsize:
       if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
@@ -104,29 +110,30 @@ class LazyBuffer:
       # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
      if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
       new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
-    elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self
+    elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
       # TODO: applying this makes gpt2 slower
       return self.base.cast(dtype, bitcast)._view(self.st)
-    cast_op:
+    cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))

-  def is_unrealized_const(self): return self.base.realized is None and self.base.op is
+  def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
   def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)

   def _copy(self, device:str) -> LazyBuffer:
-
+    assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
+    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)

-  def copy_to_device(self, device:str, force: bool
+  def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
     # no COPY
-    if self.device == device: return self
+    if self.device == device and not clone: return self

     # double COPY = one COPY
-    if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is
+    if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is Ops.COPY:
       return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)

     # const doesn't have to be copied (issues with disk tensor)
     if self.is_unrealized_const():
-      return LazyBuffer.
+      return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)

     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
     if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
@@ -134,58 +141,59 @@ class LazyBuffer:
     # copy the base and apply the shapetracker on the new device
     return self.base._copy(device)._view(self.st)

-  def
+  def clone(self) -> LazyBuffer: return self.copy_to_device(self.device, clone=True)
+
+  def alu(self, op:Ops, *in_srcs:LazyBuffer) -> LazyBuffer:
     srcs: List[LazyBuffer] = []
     for s in (self,)+in_srcs:
       if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
         srcs.append(root._view(s.base.contiguous_child[1]))
       else:
         srcs.append(s)
-
+    if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
+      raise AssertionError(f"all dtypes must match {dts} on {op}")
     assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
-    if op is
-    if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
+    if op is Ops.WHERE: assert srcs[0].dtype == dtypes.bool, "Ops.WHERE must have the first arg be bool"

-    out_dtype = dtypes.bool if op in (
+    out_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else srcs[-1].dtype

     # const folding
     if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
-      return self.cast(out_dtype).
-    if op
-    if op in BinaryOps:
+      return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
+    if op in GroupOp.Binary:
       x, y = self, in_srcs[0]
-      if op is
+      if op is Ops.ADD:
         if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
         if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
-      if op is
-        if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0
-
-
-          return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
+      if op is Ops.MUL:
+        if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const_like(0)
+        if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const_like(0)
+      if op is Ops.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x

-    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op,
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, None, tuple(srcs))

   # *** reduce ops ***

-  def _reduce_op(self, op:
+  def _reduce_op(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
     assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
-    axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
+    axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
     if len(axis) == 0: return self
-
-    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, Ops.REDUCE_AXIS, (op, axis), (self,))

-  def r(self, op:
-    new_shape =
+  def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
+    new_shape = self.st.reduce(axis)
     # TODO: this logic should move to the scheduler
-    if self.
+    if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)

     # const folding
     # TODO: fold this for symbolic?
     if self.is_unrealized_unmasked_const() and all_int(self.shape):
-      return self.
+      if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
+      if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
+      if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)

     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
-    if not
+    if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
       prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
       return self._reduce_op(op, axis)

@@ -208,7 +216,7 @@ class LazyBuffer:

   def _view(self, new_st:ShapeTracker) -> LazyBuffer:
     if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
-      return self.
+      return self.const_with_shape(0, new_st.shape)
     if new_st.contiguous and self.base.shape == new_st.shape: return self.base
     return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)

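The engine/lazy.py hunks above also show the other sweeping rename in 0.10.0: the separate UnaryOps/BinaryOps-style enums are folded into a single Ops enum, with groupings such as GroupOp.Binary used for membership tests (see "if op in GroupOp.Binary:" in LazyBuffer.alu). A small sketch, assuming tinygrad 0.10.0 is installed:

from tinygrad.ops import Ops, GroupOp

# Membership tests work the way LazyBuffer.alu uses them above.
print(Ops.ADD in GroupOp.Binary)    # True: ADD is grouped as a binary op
print(Ops.CMPNE in GroupOp.Binary)  # comparisons are binary ops; per the diff they produce dtypes.bool
print(Ops.CONST, Ops.REDUCE_AXIS)   # members of the single Ops enum used throughout engine/lazy.py above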