tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/engine/jit.py CHANGED
@@ -1,22 +1,24 @@
 from __future__ import annotations
 from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
-import functools, itertools, collections
+import functools, collections
 from tinygrad.tensor import Tensor
-from tinygrad.lazy import LazyBuffer
-from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, ContextVar, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
+from tinygrad.engine.lazy import LazyBuffer
+from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
+from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
-from tinygrad.engine.schedule import _internal_memory_planner
+from tinygrad.engine.memory import _internal_memory_planner
 from tinygrad.nn.state import get_parameters
+from dataclasses import dataclass
 from weakref import WeakKeyDictionary

-def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
+class GraphException(Exception): pass
+
+def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], max_batch_size=0) -> List[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.
-  max_batch_size = getenv("JIT_BATCH_SIZE", 32)
   graphed_jit_cache: List[ExecItem] = []
   current_batch: List[ExecItem] = []
   current_device: Optional[Compiled] = None
@@ -30,10 +32,10 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
       for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
       graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
       max_batch_size *= 2
-      if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
+      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
     except GraphException as e:
       graphed_jit_cache.extend(current_batch)
-      if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
+      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
     current_batch = []
     current_device = None

@@ -44,11 +46,11 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
     elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
       ji_graph_dev = Device[ji.bufs[0].device]

-    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
+    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None
     can_be_graphed = ji_graph_dev and ji_graph_dev.graph
     can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
-                       type(ji_graph_dev) == type(current_device))
-    can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph
+                       type(ji_graph_dev) is type(current_device))
+    can_extend_graph_batch = can_be_graphed and (max_batch_size == 0 or len(current_batch) < max_batch_size) and can_share_graph
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()

     if can_be_graphed: current_batch.append(ji)
@@ -70,129 +72,224 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
 class GraphRunner(Runner): # pylint: disable=abstract-method
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     self.jit_cache = jit_cache
-    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.jc_idx_with_updatable_launch_dims = []
-    self.jc_idx_with_updatable_var_vals = []
+    self.input_replace:Dict[Tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
+    self.var_vals_replace:Dict[int, List[int]] = {}
+    self.launch_dims_replace:Dict[int, Tuple[Optional[int], Optional[int]]] = {}
+
     op_estimate: sint = 0
     mem_estimate: sint = 0
+    lds_estimate: sint = 0
+
+    def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
+
+    self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
+    self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
+                               [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
+    def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
+
     for j,ji in enumerate(jit_cache):
       op_estimate += ji.prg.op_estimate
       mem_estimate += ji.prg.mem_estimate
+      lds_estimate += ji.prg.lds_estimate
       if isinstance(ji.prg, CompiledRunner):
-        if ji.prg.p.vars: self.jc_idx_with_updatable_var_vals.append(j)
-        if (ji.prg.p.global_size and not all_int(ji.prg.p.global_size)) or (ji.prg.p.local_size and not all_int(ji.prg.p.local_size)):
-          self.jc_idx_with_updatable_launch_dims.append(j)
-    self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
-    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)
+        if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]

-class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
-  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
-    self.w_dependency_map: Dict[Any, Any] = {}
-    self.r_dependency_map: Dict[Any, List[Any]] = collections.defaultdict(list)
-    super().__init__(jit_cache, input_rawbuffers, var_vals)
+        global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
+        if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+
+    # used in MultiGraphRunner. the ints are id() of _bufs
+    self.w_dependency_map: Dict[int, Any] = {}
+    self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
+
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
+                     ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
+
+  def updated_vars(self, var_vals: Dict[Variable, int]):
+    vals = [var_vals[v] for v in self.vars]
+    for j, vidxs in self.var_vals_replace.items():
+      for i, v in enumerate(vidxs): yield j, i, vals[v]

-  def _access_resources(self, read, write, new_dependency:Any):
+  def updated_launch_dims(self, var_vals: Dict[Variable, int]):
+    dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
+    for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)
+
+  def _access_resources(self, rawbufs:List[Buffer], write:List[int], new_dependency:Any):
     # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
     # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
     wait_nodes = []

-    for rawbuf in read + write:
+    for i,rawbuf in enumerate(rawbufs):
       if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
-    for rawbuf in write:
-      if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
+      if i in write:
+        if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
+        self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
+      else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)

-    for rawbuf in read: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
-    for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
     return list({id(x):x for x in wait_nodes}.values())

+# a marker for your graph supporting multiple devices of the same type
+class MultiGraphRunner(GraphRunner): pass # pylint: disable=abstract-method
+
 ReturnType = TypeVar('ReturnType')
-IN_JIT = ContextVar('IN_JIT', 0)
+@dataclass
+class CapturedJit(Generic[ReturnType]):
+  ret: Any # includes the Tensors or any other returned object
+  jit_cache: List[ExecItem]
+  input_replace: Dict[Tuple[int, int], int]
+  extra_view_inputs: List[Tuple[int, int, str, int, DType]]
+  expected_names: List[Union[int, str]]
+  expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]]
+
+  def __reduce__(self):
+    return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
+                            self.expected_names, self.expected_st_vars_dtype_device)
+
+  def __post_init__(self):
+    self._jit_cache: List[ExecItem] = self.jit_cache
+    self._input_replace: Dict[Tuple[int, int], int] = self.input_replace
+    self._graphed = False
+    self._clear_inputs()
+
+  def _clear_inputs(self):
+    for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
+
+  # jit exec
+  def __call__(self, input_buffers:List[Buffer], var_vals:Dict[Variable, int]) -> ReturnType:
+    # assign inputs
+    for idx, offset, device, size, dtype in self.extra_view_inputs:
+      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
+    for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
+
+    # Condense the items into a graph executor.
+    if JIT < 2 and not self._graphed:
+      self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
+      self._input_replace = get_input_replace(self._jit_cache, input_buffers)
+      self._graphed = True
+
+    if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
+    for ei in self._jit_cache: ei.run(var_vals, jit=True)
+    self._clear_inputs()
+    return self.ret
+
+def _prepare_jit_inputs(args, kwargs):
+  input_tensors: List[Tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
+  names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
+  if tensors: Tensor.realize(*tensors)
+  lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for t in tensors])
+  input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
+  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
+  st_varval_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
+  var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
+  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
+  return input_buffers, var_vals, names, st_vars_dtype_device
+
 class TinyJit(Generic[ReturnType]):
-  def __init__(self, fxn:Callable[..., ReturnType]):
+  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
+    assert fxn or captured, "need either a function or a CapturedJit"
     self.fxn = fxn
-    self.reset()
+    self.captured: Optional[CapturedJit] = captured
+    self.cnt: int = 2 if self.fxn is None else 0
+    self.prune = prune

   def add_buffer(self, b:Buffer) -> Buffer:
-    if found:=self.buffer_replace.get(b, None): return found
+    if found:=self._buffer_replace.get(b, None): return found
     if b.is_allocated() or b.lb_refcount > 0: return b
     if b._base is not None:
-      self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.buffer_replace.get(b._base, b._base), offset=b.offset)
+      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, base=self.add_buffer(b._base), offset=b.offset)
     else:
-      self.buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
+      self._buffer_replace[b] = ret = Buffer(b.device, b.size, b.dtype, options=b.options)
     return ret

   def add(self, ei:ExecItem):
-    self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
+    self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))

   def reset(self):
-    self.jit_cache: List[ExecItem] = []
-    self.input_replace: Dict[Tuple[int, int], int] = {}
-    self.extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
-    self.buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
-    self.cnt: int = 0
+    assert self.fxn is not None, "can't reset without function"
+    self.cnt = 0
+    self.captured = None
+
+  def __reduce__(self):
+    assert self.captured is not None, "can't pickle an uncaptured JIT"
+    return self.__class__, (None, self.captured)
+
+  # keep legacy code working
+  @property
+  def jit_cache(self) -> List[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
+  @property
+  def input_replace(self) -> Dict[Tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods

   def __call__(self, *args, **kwargs) -> ReturnType:
-    input_tensors: List[Tuple[Union[int, str], Tensor]] = \
-      [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
-    if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
-    names: List[Union[int, str]] = [name for name,_ in input_tensors]
-    lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
-    st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
-    input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
-    assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
-    var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
-      [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
-    st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
+    input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
     if not JIT or self.cnt == 0:
-      if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
       # jit ignore
-      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, IN_JIT=1):
-        self.ret = self.fxn(*args, **kwargs)
-        if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
+      assert self.fxn is not None
+      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
+        ret = self.fxn(*args, **kwargs)
+        if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
     elif self.cnt == 1:
       # jit capture
-      self.expected_names: List[Union[int, str]] = names
-      self.expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = st_vars_dtype_device
-      with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
+      assert self.fxn is not None
+      if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
+      self._jit_cache: List[ExecItem] = []
+      self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
+      # TODO: should we always disable the memory planner here? it must be off for prune
+      with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
         capturing.append(self)
-        self.ret = self.fxn(*args, **kwargs)
-        if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
-        capturing.clear()
-      del self.buffer_replace
-      assert len(self.jit_cache), "didn't JIT anything!"
-      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_buffers)} inputs")
+        try:
+          ret = self.fxn(*args, **kwargs)
+          if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
+        except Exception as e: raise e
+        finally: capturing.clear()
+      jit_cache = self._jit_cache
+      del self._buffer_replace, self._jit_cache
+      assert len(jit_cache), "didn't JIT anything!"
+      if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")

       # track inputs that are views of buffers
-      for item in self.jit_cache:
+      # TODO: eventually expected_buffers should live in ExecItem
+      extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
+      for item in jit_cache:
         for b in item.bufs:
           if b is not None and b._base is not None and b._base in input_buffers:
             input_buffers.append(b)
-            self.extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
+            extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
+
+      # prune independent kernels (optional)
+      if self.prune:
+        depends = set(input_buffers)
+        for ei in jit_cache:
+          if any(b in depends for b in ei.bufs):
+            if isinstance(ei.prg, CompiledRunner):
+              depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
+        pruned, onetime = partition(jit_cache,
+                                    lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
+        if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
+        # run the onetime kernels here
+        for ei in onetime:
+          for b in ei.bufs: cast(Buffer, b).ensure_allocated()
+          ei.run(var_vals, jit=True)
+        jit_cache = pruned

       # memory planning (optional)
-      assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in self.jit_cache], debug_prefix="JIT ")
-      self.jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in self.jit_cache]
+      # Exclude buffers involved in transfer ops to preserve parallelism.
+      noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
+      assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
+      jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]

-      # Condense the items into a graph executor.
-      if JIT < 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals)
+      input_replace = get_input_replace(jit_cache, input_buffers)
+      if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

-      self.input_replace = get_input_replace(self.jit_cache, input_buffers)
-      if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
+      # set this for next run
+      self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
     elif self.cnt >= 2:
       # jit exec
-      assert self.expected_names == names, f"args mismatch in JIT: {self.expected_names=} != {names}"
-      assert self.expected_st_vars_dtype_device == st_vars_dtype_device, \
-        f"args mismatch in JIT: {self.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
-      for idx, offset, device, size, dtype in self.extra_view_inputs:
-        input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
-      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_buffers[input_idx]
-      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
-      for ei in self.jit_cache: ei.run(var_vals, jit=True)
-
-      # clear jit inputs
-      for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
+      assert self.captured is not None
+      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
+      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
+        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
+      ret = self.captured(input_buffers, var_vals)

     self.cnt += 1
-    return self.ret
+    return ret
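The jit.py changes above move TinyJit's captured state into the new CapturedJit dataclass and give both classes a __reduce__, so a JIT that has already captured its kernels can be pickled and restored. A minimal sketch of that workflow (not part of the diff), assuming the public tinygrad Tensor/TinyJit API and a device whose buffers support pickling:

  import pickle
  from tinygrad import Tensor, TinyJit

  @TinyJit
  def double(x: Tensor) -> Tensor:
    return (x * 2).realize()

  # call 1 runs eagerly, call 2 captures the kernels, call 3 replays the capture
  for _ in range(3): double(Tensor.rand(4))

  blob = pickle.dumps(double)    # only valid once captured is set ("can't pickle an uncaptured JIT")
  restored = pickle.loads(blob)  # rebuilt as TinyJit(None, captured); cnt starts at 2, so it replays immediately
  print(restored(Tensor.rand(4)).numpy())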
tinygrad/{lazy.py → engine/lazy.py} RENAMED
@@ -1,44 +1,44 @@
 from __future__ import annotations
-import math
-from typing import Union, Optional, Any, Tuple, List
-from tinygrad.dtype import dtypes, DType, ConstType
-from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG
-from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
-from tinygrad.shape.symbolic import sint, Variable
+from typing import Optional, Any, Tuple, List, get_args
+from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
+from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
+from tinygrad.ops import exec_alu, python_alu
+from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
 from weakref import ref, ReferenceType, WeakValueDictionary

 lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
-def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
-                      base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
-  if st.size == 0: op, arg, srcs, base = LoadOps.CONST, 0, (), None
-  if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, Variable) else arg, True
+def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+                      base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)):
+  if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None
+  dtype = to_dtype(dtype)
+  if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True

   cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
-  if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
+  if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret

-  ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
+  ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base, metadata=_METADATA.get())
   if enable_cache: lazycache[cache_key] = ret
   return ret

-view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "DISK"}
-class LazyBuffer:
+view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
+class LazyBuffer(MathTrait):
   def __init__(self, device:str, st:ShapeTracker, dtype:DType,
-               op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
-               base:Optional[LazyBuffer]=None):
-    self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
+               op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
+               base:Optional[LazyBuffer]=None, metadata:Optional[Metadata]=None):
+    self.device, self.st, self.dtype, self.shape, self.size, self.metadata = device, st, to_dtype(dtype), st.shape, st.size, metadata
     self._base: Optional[LazyBuffer] = None
     if base is None:
       # properties on base
-      self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
-      assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
+      self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps
+      assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized"

-      if self.op is LoadOps.VIEW:
+      if self.op is Ops.BUFFER_VIEW:
         # some LazyBuffers can be processed with only a view, no AST required
-        self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
+        self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
       else:
-        self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
+        self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype)
       self.buffer.ref(1)
       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
       self.forced_realize = False
@@ -51,7 +51,7 @@ class LazyBuffer:
     if hasattr(self, 'buffer'): self.buffer.ref(-1)

   def __repr__(self) -> str:
-    return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base != self else (self.op, self.realized)}>"
+    return f"<LB {self.device} {self.shape} {str(self.dtype)[7:]} {self.st if self.base is not self else (self.op, self.realized)}>"

   @property
   def realized(self) -> Optional[Buffer]:
@@ -67,36 +67,42 @@ class LazyBuffer:
   def lbs(self) -> List[LazyBuffer]: return [self]

   @staticmethod
-  def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
+  def metaop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer:
     assert isinstance(src, tuple)
     return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)

-  def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
-    assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType"
-    shape = self.shape if shape is None else shape
-    return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
+  def const_like(self, b): return self.const_with_shape(b, self.shape)
+  def const_with_shape(self, val:ConstType, shape:Tuple[sint,...]) -> LazyBuffer:
+    assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType"
+    return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)

+  @property
   def is_realized(self) -> bool: return self.base.realized is not None

   def assign(self, x:LazyBuffer) -> LazyBuffer:
     assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
-    return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
+    assert self.is_realized, f"assign target must be realized {self}"
+    return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,),
+                             src=(self.base, x), enable_cache=True)

-  def can_view(self): return self.st.consecutive and not self.is_unrealized_const() and self.device.split(":")[0] in view_supported_devices
+  def can_view(self):
+    return (self.st.consecutive and not self.is_unrealized_const() and not isinstance(self.dtype, ImageDType) and
+            self.device.split(":")[0] in view_supported_devices)

   def contiguous(self, allow_buffer_view=True):
     if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
-      ret = self.e(LoadOps.VIEW) if allow_buffer_view and self.can_view() else self.e(LoadOps.CONTIGUOUS)
+      ret = self.alu(Ops.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(Ops.CONTIGUOUS)
       if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
       return ret
     self.base.forced_realize = True
     return self

-  def cast(self, dtype:DType, bitcast:bool=False):
+  def bitcast(self, dtype:DType) -> LazyBuffer: return self.cast(dtype, bitcast=True)
+  def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True) -> LazyBuffer:
     if self.dtype == dtype: return self
     if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
     if self.is_unrealized_unmasked_const() and not bitcast:
-      return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
+      return create_lazybuffer(self.device, self.st, dtype, Ops.CONST, dtypes.as_const(self.base.arg, dtype))
     new_shape = self.shape
     if bitcast and self.dtype.itemsize != dtype.itemsize:
       if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
@@ -104,29 +110,30 @@ class LazyBuffer:
       # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
       if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
       new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
-    elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
+    elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
       # TODO: applying this makes gpt2 slower
       return self.base.cast(dtype, bitcast)._view(self.st)
-    cast_op: Union[LoadOps, UnaryOps] = (LoadOps.VIEW if self.can_view() else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
+    cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))

-  def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
+  def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
   def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)

   def _copy(self, device:str) -> LazyBuffer:
-    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, self.buffer.nbytes, (self,), enable_cache=False)
+    assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
+    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)

-  def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
+  def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
     # no COPY
-    if self.device == device: return self
+    if self.device == device and not clone: return self

     # double COPY = one COPY
-    if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
+    if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is Ops.COPY:
       return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)

     # const doesn't have to be copied (issues with disk tensor)
     if self.is_unrealized_const():
-      return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
+      return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)

     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
     if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
@@ -134,58 +141,59 @@ class LazyBuffer:
     # copy the base and apply the shapetracker on the new device
     return self.base._copy(device)._view(self.st)

-  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
+  def clone(self) -> LazyBuffer: return self.copy_to_device(self.device, clone=True)
+
+  def alu(self, op:Ops, *in_srcs:LazyBuffer) -> LazyBuffer:
     srcs: List[LazyBuffer] = []
     for s in (self,)+in_srcs:
       if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
         srcs.append(root._view(s.base.contiguous_child[1]))
       else:
         srcs.append(s)
-    assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}"
+    if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
+      raise AssertionError(f"all dtypes must match {dts} on {op}")
     assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
-    if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
-    if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
+    if op is Ops.WHERE: assert srcs[0].dtype == dtypes.bool, "Ops.WHERE must have the first arg be bool"

-    out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype
+    out_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else srcs[-1].dtype

     # const folding
     if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
-      return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
-    if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG and self.base.realized is None: return self.base.srcs[0]
-    if op in BinaryOps:
+      return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
+    if op in GroupOp.Binary:
       x, y = self, in_srcs[0]
-      if op is BinaryOps.ADD:
+      if op is Ops.ADD:
         if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
         if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
-      if op is BinaryOps.MUL:
-        if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
-          return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
-        if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0, -1):
-          return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
+      if op is Ops.MUL:
+        if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const_like(0)
+        if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const_like(0)
+      if op is Ops.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x

-    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, None, tuple(srcs))

   # *** reduce ops ***

-  def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+  def _reduce_op(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
     assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
-    axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
+    axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
     if len(axis) == 0: return self
-    new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
-    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, Ops.REDUCE_AXIS, (op, axis), (self,))

-  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
-    new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
+  def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
+    new_shape = self.st.reduce(axis)
     # TODO: this logic should move to the scheduler
-    if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
+    if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)

     # const folding
     # TODO: fold this for symbolic?
     if self.is_unrealized_unmasked_const() and all_int(self.shape):
-      return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
+      if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
+      if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
+      if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)

     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
-    if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
+    if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
       prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
       return self._reduce_op(op, axis)

@@ -208,7 +216,7 @@ class LazyBuffer:

   def _view(self, new_st:ShapeTracker) -> LazyBuffer:
     if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
-      return self.const(0, new_st.shape)
+      return self.const_with_shape(0, new_st.shape)
     if new_st.contiguous and self.base.shape == new_st.shape: return self.base
     return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)

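Taken together with the file list, the lazy.py diff reflects this release's consolidation of internals: tinygrad/lazy.py now lives at tinygrad/engine/lazy.py, tinygrad/shape/symbolic.py is gone, and the separate LoadOps/UnaryOps/BinaryOps/TernaryOps/ReduceOps enums are folded into the single Ops enum, with GroupOp for grouping and UOp/Variable/sint exported from tinygrad.ops. A rough migration sketch for hypothetical out-of-tree code that imported these internals; the symbols shown are only the ones visible in the diffs above, and such private imports may keep changing:

  # tinygrad 0.9.1 internal imports
  # from tinygrad.lazy import LazyBuffer
  # from tinygrad.ops import LoadOps, BinaryOps
  # from tinygrad.shape.symbolic import Variable, sint

  # tinygrad 0.10.0: same names, new homes, one Ops enum
  from tinygrad.engine.lazy import LazyBuffer
  from tinygrad.ops import Ops, GroupOp, UOp, Variable, sint

  def is_binary_add(op: Ops) -> bool:
    # what used to be spelled `op is BinaryOps.ADD` / `op in BinaryOps` now uses Ops and GroupOp
    return op is Ops.ADD and op in GroupOp.Binary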