tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/engine/jit.py
CHANGED
@@ -1,14 +1,12 @@
-from
-from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
+from typing import TypeVar, Generic, Callable, Union, cast, Optional, Any
 import functools, collections
 from tinygrad.tensor import Tensor
-from tinygrad.
-from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition
+from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
-from tinygrad.ops import UOp,
+from tinygrad.ops import UOp, Variable, sym_infer, Ops
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.engine.realize import ExecItem, capturing,
+from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
 from tinygrad.engine.memory import _internal_memory_planner
 from tinygrad.nn.state import get_parameters
 from dataclasses import dataclass
@@ -16,11 +14,11 @@ from weakref import WeakKeyDictionary

 class GraphException(Exception): pass

-def apply_graph_to_jit(jit_cache:
+def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.
-  graphed_jit_cache:
-  current_batch:
+  graphed_jit_cache: list[ExecItem] = []
+  current_batch: list[ExecItem] = []
   current_device: Optional[Compiled] = None

   def flush_batch():
@@ -30,7 +28,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
       graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
       # clear jit inputs to allow their memory to be freed/reused
       for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
-      graphed_jit_cache.append(ExecItem(graph_runner, cast(
+      graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Optional[Buffer]], input_rawbuffers)))
       max_batch_size *= 2
       if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
     except GraphException as e:
@@ -40,9 +38,9 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
     current_device = None

   for ji in jit_cache:
-    if ji.prg
+    if isinstance(ji.prg, ViewOp): continue
     ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
-    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.
+    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
     elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
       ji_graph_dev = Device[ji.bufs[0].device]

@@ -61,24 +59,21 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
   if len(current_batch) > 0: flush_batch()
   return graphed_jit_cache

-def get_input_replace(jit_cache:
-  input_replace:
+def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer]) -> dict[tuple[int, int], int]:
+  input_replace: dict[tuple[int, int], int] = {}
   for j,ji in enumerate(jit_cache):
     for i,a in enumerate(ji.bufs):
       if a in input_rawbuffers:
         input_replace[(j,i)] = input_rawbuffers.index(a)
   return input_replace

-class GraphRunner(Runner):
-  def __init__(self, jit_cache:
-    self.jit_cache = jit_cache
-    self.input_replace:
-    self.var_vals_replace:
-    self.launch_dims_replace:
-
-    op_estimate: sint = 0
-    mem_estimate: sint = 0
-    lds_estimate: sint = 0
+class GraphRunner(Runner):
+  def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
+    self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
+    self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
+    self.var_vals_replace:dict[int, list[int]] = {}
+    self.launch_dims_replace:dict[int, tuple[Optional[int], Optional[int]]] = {}
+    self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}

     def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

@@ -87,33 +82,35 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
       [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
     def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None

+    estimates = Estimates()
     for j,ji in enumerate(jit_cache):
-
-      mem_estimate += ji.prg.mem_estimate
-      lds_estimate += ji.prg.lds_estimate
+      estimates += ji.prg.estimates
       if isinstance(ji.prg, CompiledRunner):
         if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]

         global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
-        if global_dim_idx is not None or local_dim_idx is not None:
+        if global_dim_idx is not None or local_dim_idx is not None:
+          self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+          assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
+          self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))

     # used in MultiGraphRunner. the ints are id() of _bufs
-    self.w_dependency_map:
-    self.r_dependency_map:
+    self.w_dependency_map: dict[int, Any] = {}
+    self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)

-    super().__init__(colored(f"<batched {len(
-      ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
+    super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())

-  def updated_vars(self, var_vals:
+  def updated_vars(self, var_vals: dict[Variable, int]):
     vals = [var_vals[v] for v in self.vars]
     for j, vidxs in self.var_vals_replace.items():
       for i, v in enumerate(vidxs): yield j, i, vals[v]

-  def updated_launch_dims(self, var_vals:
+  def updated_launch_dims(self, var_vals: dict[Variable, int]):
     dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
-    for j, (gl, lc) in self.launch_dims_replace.items():
+    for j, (gl, lc) in self.launch_dims_replace.items():
+      yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])

-  def _access_resources(self, rawbufs:
+  def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependency:Any):
     # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
     # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
     wait_nodes = []
@@ -128,43 +125,65 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
     return list({id(x):x for x in wait_nodes}.values())

 # a marker for your graph supporting multiple devices of the same type
-class MultiGraphRunner(GraphRunner): pass
+class MultiGraphRunner(GraphRunner): pass
+
+def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
+  for ei in jit_cache:
+    if any(b in depends for b in ei.bufs):
+      if isinstance(ei.prg, CompiledRunner):
+        depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins)
+      if isinstance(ei.prg, (BufferCopy, BufferXfer)):
+        depends.add(cast(Buffer, ei.bufs[0]))

 ReturnType = TypeVar('ReturnType')
 @dataclass
 class CapturedJit(Generic[ReturnType]):
   ret: Any # includes the Tensors or any other returned object
-  jit_cache:
-  input_replace:
-  extra_view_inputs:
-  expected_names:
-  expected_st_vars_dtype_device:
+  jit_cache: list[ExecItem]
+  input_replace: dict[tuple[int, int], int]
+  extra_view_inputs: list[tuple[int, int, str, int, DType]]
+  expected_names: list[Union[int, str]]
+  expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]

   def __reduce__(self):
+    # TODO: free_intermediates here?
     return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                             self.expected_names, self.expected_st_vars_dtype_device)

   def __post_init__(self):
-    self._jit_cache:
-    self._input_replace:
-    self.
+    self._jit_cache: list[ExecItem] = self.jit_cache
+    self._input_replace: dict[tuple[int, int], int] = self.input_replace
+    self._first_run = True
     self._clear_inputs()

   def _clear_inputs(self):
     for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None

+  def free_intermediates(self):
+    depends: set[Buffer|None] = set([None])
+    update_depends(depends, self.jit_cache)
+    for b in depends:
+      if b is not None: b.deallocate()
+    self.__post_init__() # reset the graph state
+
   # jit exec
-  def __call__(self, input_buffers:
+  def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
     # assign inputs
     for idx, offset, device, size, dtype in self.extra_view_inputs:
       input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
     for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]

     # Condense the items into a graph executor.
-    if
-
-
-
+    if self._first_run:
+      # allocate intermediates if freed
+      for ji in self.jit_cache:
+        for b in ji.bufs:
+          if b is not None: b.ensure_allocated()
+      # create graph if needed
+      if JIT < 2:
+        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
+        self._input_replace = get_input_replace(self._jit_cache, input_buffers)
+      self._first_run = False

     if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
     for ei in self._jit_cache: ei.run(var_vals, jit=True)
@@ -172,13 +191,14 @@ class CapturedJit(Generic[ReturnType]):
     return self.ret

 def _prepare_jit_inputs(args, kwargs):
-  input_tensors:
+  input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
   names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
   if tensors: Tensor.realize(*tensors)
-
-
+  # TODO: should we be unpacking multi here?
+  lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
+  input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
   assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
-  st_varval_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
+  st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
   var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
   st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
   return input_buffers, var_vals, names, st_vars_dtype_device
@@ -214,9 +234,9 @@ class TinyJit(Generic[ReturnType]):

   # keep legacy code working
   @property
-  def jit_cache(self) ->
+  def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
   @property
-  def input_replace(self) ->
+  def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods

@@ -232,7 +252,7 @@ class TinyJit(Generic[ReturnType]):
       # jit capture
       assert self.fxn is not None
       if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
-      self._jit_cache:
+      self._jit_cache: list[ExecItem] = []
      self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
       # TODO: should we always disable the memory planner here? it must be off for prune
       with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
@@ -249,7 +269,7 @@ class TinyJit(Generic[ReturnType]):

       # track inputs that are views of buffers
       # TODO: eventually expected_buffers should live in ExecItem
-      extra_view_inputs:
+      extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
       for item in jit_cache:
         for b in item.bufs:
           if b is not None and b._base is not None and b._base in input_buffers:
@@ -259,10 +279,7 @@ class TinyJit(Generic[ReturnType]):
       # prune independent kernels (optional)
       if self.prune:
         depends = set(input_buffers)
-
-        if any(b in depends for b in ei.bufs):
-          if isinstance(ei.prg, CompiledRunner):
-            depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
+        update_depends(depends, jit_cache)
         pruned, onetime = partition(jit_cache,
                                     lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
         if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
@@ -275,7 +292,7 @@ class TinyJit(Generic[ReturnType]):
       # memory planning (optional)
       # Exclude buffers involved in transfer ops to preserve parallelism.
       noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
-      assigned = _internal_memory_planner([cast(
+      assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
       jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]

       input_replace = get_input_replace(jit_cache, input_buffers)
tinygrad/engine/memory.py
CHANGED
@@ -1,4 +1,3 @@
-from typing import List, Union, Tuple, Dict
 from collections import defaultdict
 from tinygrad.engine.schedule import ScheduleItem
 from tinygrad.device import Device, Buffer
@@ -7,7 +6,7 @@ from tinygrad.ops import Ops

 # **************** memory planning ****************

-def _internal_memory_planner(buffers:
+def _internal_memory_planner(buffers:list[list[Buffer]|tuple[Buffer, ...]], noopt_buffers=None, debug_prefix="") -> dict[Buffer, Buffer]:
   if NO_MEMORY_PLANNER: return {}
   first_appearance, last_appearance = {}, {}
   for i,u in enumerate(buffers):
@@ -18,7 +17,7 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]

   # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
   # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
-  free_segs:
+  free_segs: dict[tuple, list[tuple[int, int, Buffer]]] = defaultdict(list) # dict[buffer key, tuple[start, end, buffer to reuse on the seg]]
   def find_replace_buffer(buf, st, en):
     key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())

@@ -44,8 +43,8 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]
                                              f"{len(ak)} -> {len(av)} bufs")
   return assigned

-def memory_planner(schedule:
+def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
   # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
   assigned = _internal_memory_planner([si.bufs for si in schedule],
                                       noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
-  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata
+  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
tinygrad/engine/multi.py
ADDED
@@ -0,0 +1,162 @@
+from __future__ import annotations
+import functools, itertools, operator
+from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
+from tinygrad.ops import Ops, UOp, sint
+
+def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
+  assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
+  assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
+  n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
+  # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+  # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+  use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
+  if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+
+  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
+  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
+  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
+  chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
+  chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
+
+  # scatter-reduce
+  for step in range(n_lbs-1):
+    for i in range(len(chunks)):
+      src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+      chunked[dest][i] = chunked[dest][i].alu(bop, chunked[src][i].copy_to_device(chunked[dest][i].device))
+
+  # allgather
+  for step in range(n_lbs-1):
+    for i in range(len(chunks)):
+      src, dest = (i+step-1)%n_lbs, (i+step)%n_lbs
+      chunked[dest][i] = chunked[src][i].copy_to_device(chunked[dest][i].device)
+
+  # assemble chunks back
+  pads = [((s,numel-e),) for s,e in chunks]
+  return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]
+
+def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> list[UOp]:
+  if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
+  return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
+
+# ***** multi functions *****
+
+from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
+
+def alu_multi(root:UOp):
+  msrcs = root.src
+  assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}"
+  assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+
+  axis = root.axis
+  bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None
+  srcs:list[list[UOp]] = []
+  not_all_real = not all(all(mlb.real) for mlb in msrcs)
+  new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
+  for mlb in msrcs:
+    if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src))
+    else:
+      assert axis is not None and bounds is not None
+      if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds))
+      else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds))
+  new_lbs = [lsrcs[0].alu(root.op, *lsrcs[1:]) for lsrcs in zip(*srcs)]
+  new_lbs = [x if r else x.const_like(0) for r,x in zip(new_real, new_lbs)] # TODO: is this needed?
+  return UOp.multi(*new_lbs, axis=axis, real=new_real)
+
+def reduce_multi(root:UOp, multi:UOp):
+  op, axis = root.arg
+  if multi.axis is not None and multi.axis in axis:
+    # all-reduce on sharded axes
+    reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
+    # if all partitions are real, do all_reduce
+    if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=root.axis)
+    # only one partition is real, keep it
+    return UOp.multi(*reduced_parts, axis=root.axis, real=multi.real)
+  # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+  return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=root.axis, real=multi.real)
+
+def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
+  return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
+
+def reshape_multi(root:UOp, multi:UOp):
+  arg = root.arg
+  if (new_axis:=root.axis) is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real)
+  assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
+  assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
+    f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
+  lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
+  return UOp.multi(*lbs, axis=new_axis, real=multi.real)
+
+def expand_multi(root:UOp, multi:UOp):
+  # NOTE: this assert isn't needed, sharded axis can have dim 1
+  assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
+  return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real)
+
+def pad_multi(root:UOp, multi:UOp):
+  assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}"
+  # pad on shard axis -> fill others with zeros and set real to all True
+  if multi.axis is not None and root.arg[multi.axis] != (0,0):
+    # pad back to whole axis, remove real mask
+    assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time"
+    dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)]
+    assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
+    return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis)
+  return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
+
+def permute_multi(root:UOp, multi:UOp):
+  # all permutes supported!
+  return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.axis, real=multi.real)
+
+def shrink_multi(root:UOp, multi:UOp):
+  assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
+    f"shrinking not supported for {root.arg=}"
+  if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
+    assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
+      "cannot shrink sharded and non-sharded axis at the same time"
+    # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
+    idx = multi.bounds.index(root.arg[multi.axis])
+    # zero out other lbs to not create lb reference
+    return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)],
+                     axis=multi.axis, real=tuple(i==idx for i in range(len(multi.src))))
+  return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
+                   axis=multi.axis, real=multi.real)
+
+def flip_multi(root:UOp, multi:UOp):
+  assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
+  return UOp.multi(*[x.flip(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
+
+def copy_multi(multi:UOp, device:UOp):
+  # if we already have a copy on the device, return that
+  if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg))
+  # copy lbs to device, pad to final shape, and sum
+  llbs:list[UOp] = []
+  for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds):
+    if not real: continue
+    pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape)))
+    llbs.append(lb.copy_to_device(device.arg).pad(pad_arg))
+  return functools.reduce(operator.add, llbs)
+
+def assign_multi(dest:UOp, src:UOp):
+  assert dest.axis == src.axis and dest.real == src.real, f"axis/real must match in assign {dest.axis} != {src.axis} or {dest.real} != {src.real}"
+  return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real)
+
+def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real)
+
+# NOTE: this is the same pattern as Ops.UNROLL
+multi_pm = PatternMatcher([
+  (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
+  (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
+  (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
+  (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
+  (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
+  (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
+  (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
+  (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
+  (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
+  (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
+    src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
+])
+
+@track_rewrites(named=True)
+def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v}