tinygrad 0.10.0-py3-none-any.whl → 0.10.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/engine/jit.py CHANGED
@@ -1,14 +1,12 @@
- from __future__ import annotations
- from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
+ from typing import TypeVar, Generic, Callable, Union, cast, Optional, Any
  import functools, collections
  from tinygrad.tensor import Tensor
- from tinygrad.engine.lazy import LazyBuffer
- from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition
+ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
  from tinygrad.device import Buffer, Compiled, Device
  from tinygrad.dtype import DType
- from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
+ from tinygrad.ops import UOp, Variable, sym_infer, Ops
  from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
+ from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
  from tinygrad.engine.memory import _internal_memory_planner
  from tinygrad.nn.state import get_parameters
  from dataclasses import dataclass
@@ -16,11 +14,11 @@ from weakref import WeakKeyDictionary

  class GraphException(Exception): pass

- def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], max_batch_size=0) -> List[ExecItem]:
+ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
  # Split JIT cache into batches for faster graph execution.
  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
- graphed_jit_cache: List[ExecItem] = []
- current_batch: List[ExecItem] = []
+ graphed_jit_cache: list[ExecItem] = []
+ current_batch: list[ExecItem] = []
  current_device: Optional[Compiled] = None

  def flush_batch():
@@ -30,7 +28,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
  graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
  # clear jit inputs to allow their memory to be freed/reused
  for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
- graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
+ graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Optional[Buffer]], input_rawbuffers)))
  max_batch_size *= 2
  if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
  except GraphException as e:
@@ -40,9 +38,9 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
  current_device = None

  for ji in jit_cache:
- if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
+ if isinstance(ji.prg, ViewOp): continue
  ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
- if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
+ if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
  elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
  ji_graph_dev = Device[ji.bufs[0].device]

@@ -61,24 +59,21 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
  if len(current_batch) > 0: flush_batch()
  return graphed_jit_cache

- def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
- input_replace: Dict[Tuple[int, int], int] = {}
+ def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer]) -> dict[tuple[int, int], int]:
+ input_replace: dict[tuple[int, int], int] = {}
  for j,ji in enumerate(jit_cache):
  for i,a in enumerate(ji.bufs):
  if a in input_rawbuffers:
  input_replace[(j,i)] = input_rawbuffers.index(a)
  return input_replace

- class GraphRunner(Runner): # pylint: disable=abstract-method
- def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
- self.jit_cache = jit_cache
- self.input_replace:Dict[Tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
- self.var_vals_replace:Dict[int, List[int]] = {}
- self.launch_dims_replace:Dict[int, Tuple[Optional[int], Optional[int]]] = {}
-
- op_estimate: sint = 0
- mem_estimate: sint = 0
- lds_estimate: sint = 0
+ class GraphRunner(Runner):
+ def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
+ self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
+ self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
+ self.var_vals_replace:dict[int, list[int]] = {}
+ self.launch_dims_replace:dict[int, tuple[Optional[int], Optional[int]]] = {}
+ self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}

  def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

@@ -87,33 +82,35 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
  [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
  def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None

+ estimates = Estimates()
  for j,ji in enumerate(jit_cache):
- op_estimate += ji.prg.op_estimate
- mem_estimate += ji.prg.mem_estimate
- lds_estimate += ji.prg.lds_estimate
+ estimates += ji.prg.estimates
  if isinstance(ji.prg, CompiledRunner):
  if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]

  global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
- if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+ if global_dim_idx is not None or local_dim_idx is not None:
+ self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+ assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
+ self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))

  # used in MultiGraphRunner. the ints are id() of _bufs
- self.w_dependency_map: Dict[int, Any] = {}
- self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
+ self.w_dependency_map: dict[int, Any] = {}
+ self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)

- super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
- ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
+ super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())

- def updated_vars(self, var_vals: Dict[Variable, int]):
+ def updated_vars(self, var_vals: dict[Variable, int]):
  vals = [var_vals[v] for v in self.vars]
  for j, vidxs in self.var_vals_replace.items():
  for i, v in enumerate(vidxs): yield j, i, vals[v]

- def updated_launch_dims(self, var_vals: Dict[Variable, int]):
+ def updated_launch_dims(self, var_vals: dict[Variable, int]):
  dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
- for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)
+ for j, (gl, lc) in self.launch_dims_replace.items():
+ yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])

- def _access_resources(self, rawbufs:List[Buffer], write:List[int], new_dependency:Any):
+ def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependency:Any):
  # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
  # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
  wait_nodes = []
@@ -128,43 +125,65 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
  return list({id(x):x for x in wait_nodes}.values())

  # a marker for your graph supporting multiple devices of the same type
- class MultiGraphRunner(GraphRunner): pass # pylint: disable=abstract-method
+ class MultiGraphRunner(GraphRunner): pass
+
+ def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
+ for ei in jit_cache:
+ if any(b in depends for b in ei.bufs):
+ if isinstance(ei.prg, CompiledRunner):
+ depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins)
+ if isinstance(ei.prg, (BufferCopy, BufferXfer)):
+ depends.add(cast(Buffer, ei.bufs[0]))

  ReturnType = TypeVar('ReturnType')
  @dataclass
  class CapturedJit(Generic[ReturnType]):
  ret: Any # includes the Tensors or any other returned object
- jit_cache: List[ExecItem]
- input_replace: Dict[Tuple[int, int], int]
- extra_view_inputs: List[Tuple[int, int, str, int, DType]]
- expected_names: List[Union[int, str]]
- expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]]
+ jit_cache: list[ExecItem]
+ input_replace: dict[tuple[int, int], int]
+ extra_view_inputs: list[tuple[int, int, str, int, DType]]
+ expected_names: list[Union[int, str]]
+ expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]

  def __reduce__(self):
+ # TODO: free_intermediates here?
  return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
  self.expected_names, self.expected_st_vars_dtype_device)

  def __post_init__(self):
- self._jit_cache: List[ExecItem] = self.jit_cache
- self._input_replace: Dict[Tuple[int, int], int] = self.input_replace
- self._graphed = False
+ self._jit_cache: list[ExecItem] = self.jit_cache
+ self._input_replace: dict[tuple[int, int], int] = self.input_replace
+ self._first_run = True
  self._clear_inputs()

  def _clear_inputs(self):
  for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None

+ def free_intermediates(self):
+ depends: set[Buffer|None] = set([None])
+ update_depends(depends, self.jit_cache)
+ for b in depends:
+ if b is not None: b.deallocate()
+ self.__post_init__() # reset the graph state
+
  # jit exec
- def __call__(self, input_buffers:List[Buffer], var_vals:Dict[Variable, int]) -> ReturnType:
+ def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
  # assign inputs
  for idx, offset, device, size, dtype in self.extra_view_inputs:
  input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
  for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]

  # Condense the items into a graph executor.
- if JIT < 2 and not self._graphed:
- self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
- self._input_replace = get_input_replace(self._jit_cache, input_buffers)
- self._graphed = True
+ if self._first_run:
+ # allocate intermediates if freed
+ for ji in self.jit_cache:
+ for b in ji.bufs:
+ if b is not None: b.ensure_allocated()
+ # create graph if needed
+ if JIT < 2:
+ self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
+ self._input_replace = get_input_replace(self._jit_cache, input_buffers)
+ self._first_run = False

  if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
  for ei in self._jit_cache: ei.run(var_vals, jit=True)
@@ -172,13 +191,14 @@ class CapturedJit(Generic[ReturnType]):
  return self.ret

  def _prepare_jit_inputs(args, kwargs):
- input_tensors: List[Tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
+ input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
  names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
  if tensors: Tensor.realize(*tensors)
- lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for t in tensors])
- input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
+ # TODO: should we be unpacking multi here?
+ lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
+ input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
- st_varval_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
+ st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
  var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
  return input_buffers, var_vals, names, st_vars_dtype_device
@@ -214,9 +234,9 @@ class TinyJit(Generic[ReturnType]):

  # keep legacy code working
  @property
- def jit_cache(self) -> List[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
+ def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
  @property
- def input_replace(self) -> Dict[Tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}
+ def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods

@@ -232,7 +252,7 @@ class TinyJit(Generic[ReturnType]):
  # jit capture
  assert self.fxn is not None
  if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
- self._jit_cache: List[ExecItem] = []
+ self._jit_cache: list[ExecItem] = []
  self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
  # TODO: should we always disable the memory planner here? it must be off for prune
  with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
@@ -249,7 +269,7 @@ class TinyJit(Generic[ReturnType]):

  # track inputs that are views of buffers
  # TODO: eventually expected_buffers should live in ExecItem
- extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
+ extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
  for item in jit_cache:
  for b in item.bufs:
  if b is not None and b._base is not None and b._base in input_buffers:
@@ -259,10 +279,7 @@ class TinyJit(Generic[ReturnType]):
  # prune independent kernels (optional)
  if self.prune:
  depends = set(input_buffers)
- for ei in jit_cache:
- if any(b in depends for b in ei.bufs):
- if isinstance(ei.prg, CompiledRunner):
- depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
+ update_depends(depends, jit_cache)
  pruned, onetime = partition(jit_cache,
  lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
  if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
@@ -275,7 +292,7 @@ class TinyJit(Generic[ReturnType]):
  # memory planning (optional)
  # Exclude buffers involved in transfer ops to preserve parallelism.
  noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
- assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
+ assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
  jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]

  input_replace = get_input_replace(jit_cache, input_buffers)
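
The CapturedJit changes above replace the one-shot _graphed flag with _first_run and add free_intermediates(), which deallocates intermediate buffers; on the next call the buffers are re-allocated via ensure_allocated() before the graph is (re)built. A minimal usage sketch under those semantics (the toy function, shapes, and call counts below are illustrative, not taken from this diff):

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x).relu().realize()

for _ in range(3): step(Tensor.rand(16, 16))   # warm-up calls trigger capture, then graphed execution
if step.captured is not None:
  step.captured.free_intermediates()           # drop intermediate buffers while the model is idle
step(Tensor.rand(16, 16))                      # buffers are re-allocated on the first run after freeing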
tinygrad/engine/memory.py CHANGED
@@ -1,4 +1,3 @@
- from typing import List, Union, Tuple, Dict
  from collections import defaultdict
  from tinygrad.engine.schedule import ScheduleItem
  from tinygrad.device import Device, Buffer
@@ -7,7 +6,7 @@ from tinygrad.ops import Ops

  # **************** memory planning ****************

- def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
+ def _internal_memory_planner(buffers:list[list[Buffer]|tuple[Buffer, ...]], noopt_buffers=None, debug_prefix="") -> dict[Buffer, Buffer]:
  if NO_MEMORY_PLANNER: return {}
  first_appearance, last_appearance = {}, {}
  for i,u in enumerate(buffers):
@@ -18,7 +17,7 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]

  # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
  # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
- free_segs: Dict[Tuple, List[Tuple[int, int, Buffer]]] = defaultdict(list) # Dict[buffer key, Tuple[start, end, buffer to reuse on the seg]]
+ free_segs: dict[tuple, list[tuple[int, int, Buffer]]] = defaultdict(list) # dict[buffer key, tuple[start, end, buffer to reuse on the seg]]
  def find_replace_buffer(buf, st, en):
  key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())

@@ -44,8 +43,8 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]
  f"{len(ak)} -> {len(av)} bufs")
  return assigned

- def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
+ def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
  # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
  assigned = _internal_memory_planner([si.bufs for si in schedule],
  noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
- return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule]
+ return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
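
The memory-planner changes here are mostly the move to builtin generics plus ScheduleItem dropping its assign_preloads field. As background for what _internal_memory_planner is doing, below is a self-contained toy sketch of the same idea, greedy reuse of buffers whose lifetimes do not overlap; the plan helper, its string-named buffers, and the sizes are hypothetical and not tinygrad's API (the real planner also keys reuse on device, dtype, and buffer options):

# toy sketch only: greedy lifetime-based buffer reuse, same idea as _internal_memory_planner
def plan(usage: list[list[str]], nbytes: dict[str, int]) -> dict[str, str]:
  first: dict[str, int] = {}
  last: dict[str, int] = {}
  for i, step in enumerate(usage):          # usage[i] = buffers touched by schedule item i
    for b in step:
      first.setdefault(b, i)
      last[b] = i
  assigned: dict[str, str] = {}
  free: list[tuple[int, str]] = []          # (last step the buffer is used, allocation it maps to)
  for b in sorted(first, key=lambda x: -nbytes[x]):   # largest buffers first, like the real planner
    hit = next((seg for seg in free if seg[0] < first[b] and nbytes[seg[1]] >= nbytes[b]), None)
    if hit is not None:
      free.remove(hit)
      assigned[b] = hit[1]                  # b reuses an allocation that is already dead
    free.append((last[b], assigned.get(b, b)))
  return assigned

# 'c' is first used only after 'a' is dead, so it reuses a's allocation: {'c': 'a'}
print(plan([["a"], ["a", "b"], ["b", "c"], ["c"]], {"a": 100, "b": 100, "c": 100}))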
tinygrad/engine/multi.py ADDED
@@ -0,0 +1,162 @@
+ from __future__ import annotations
+ import functools, itertools, operator
+ from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
+ from tinygrad.ops import Ops, UOp, sint
+
+ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
+ assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
+ assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
+ n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
+ # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+ # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+ use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+ if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
+ if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+
+ factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
+ base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
+ chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
+ chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
+ chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
+
+ # scatter-reduce
+ for step in range(n_lbs-1):
+ for i in range(len(chunks)):
+ src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+ chunked[dest][i] = chunked[dest][i].alu(bop, chunked[src][i].copy_to_device(chunked[dest][i].device))
+
+ # allgather
+ for step in range(n_lbs-1):
+ for i in range(len(chunks)):
+ src, dest = (i+step-1)%n_lbs, (i+step)%n_lbs
+ chunked[dest][i] = chunked[src][i].copy_to_device(chunked[dest][i].device)
+
+ # assemble chunks back
+ pads = [((s,numel-e),) for s,e in chunks]
+ return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]
+
+ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> list[UOp]:
+ if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
+ return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
+
+ # ***** multi functions *****
+
+ from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
+
+ def alu_multi(root:UOp):
+ msrcs = root.src
+ assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}"
+ assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+
+ axis = root.axis
+ bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None
+ srcs:list[list[UOp]] = []
+ not_all_real = not all(all(mlb.real) for mlb in msrcs)
+ new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
+ for mlb in msrcs:
+ if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src))
+ else:
+ assert axis is not None and bounds is not None
+ if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds))
+ else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds))
+ new_lbs = [lsrcs[0].alu(root.op, *lsrcs[1:]) for lsrcs in zip(*srcs)]
+ new_lbs = [x if r else x.const_like(0) for r,x in zip(new_real, new_lbs)] # TODO: is this needed?
+ return UOp.multi(*new_lbs, axis=axis, real=new_real)
+
+ def reduce_multi(root:UOp, multi:UOp):
+ op, axis = root.arg
+ if multi.axis is not None and multi.axis in axis:
+ # all-reduce on sharded axes
+ reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
+ # if all partitions are real, do all_reduce
+ if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=root.axis)
+ # only one partition is real, keep it
+ return UOp.multi(*reduced_parts, axis=root.axis, real=multi.real)
+ # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+ return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=root.axis, real=multi.real)
+
+ def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
+ return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
+
+ def reshape_multi(root:UOp, multi:UOp):
+ arg = root.arg
+ if (new_axis:=root.axis) is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real)
+ assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
+ assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
+ f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
+ lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
+ return UOp.multi(*lbs, axis=new_axis, real=multi.real)
+
+ def expand_multi(root:UOp, multi:UOp):
+ # NOTE: this assert isn't needed, sharded axis can have dim 1
+ assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
+ return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real)
+
+ def pad_multi(root:UOp, multi:UOp):
+ assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}"
+ # pad on shard axis -> fill others with zeros and set real to all True
+ if multi.axis is not None and root.arg[multi.axis] != (0,0):
+ # pad back to whole axis, remove real mask
+ assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time"
+ dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)]
+ assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
+ return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis)
+ return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
+
+ def permute_multi(root:UOp, multi:UOp):
+ # all permutes supported!
+ return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.axis, real=multi.real)
+
+ def shrink_multi(root:UOp, multi:UOp):
+ assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
+ f"shrinking not supported for {root.arg=}"
+ if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
+ assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
+ "cannot shrink sharded and non-sharded axis at the same time"
+ # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
+ idx = multi.bounds.index(root.arg[multi.axis])
+ # zero out other lbs to not create lb reference
+ return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)],
+ axis=multi.axis, real=tuple(i==idx for i in range(len(multi.src))))
+ return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
+ axis=multi.axis, real=multi.real)
+
+ def flip_multi(root:UOp, multi:UOp):
+ assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
+ return UOp.multi(*[x.flip(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
+
+ def copy_multi(multi:UOp, device:UOp):
+ # if we already have a copy on the device, return that
+ if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg))
+ # copy lbs to device, pad to final shape, and sum
+ llbs:list[UOp] = []
+ for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds):
+ if not real: continue
+ pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape)))
+ llbs.append(lb.copy_to_device(device.arg).pad(pad_arg))
+ return functools.reduce(operator.add, llbs)
+
+ def assign_multi(dest:UOp, src:UOp):
+ assert dest.axis == src.axis and dest.real == src.real, f"axis/real must match in assign {dest.axis} != {src.axis} or {dest.real} != {src.real}"
+ return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real)
+
+ def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real)
+
+ # NOTE: this is the same pattern as Ops.UNROLL
+ multi_pm = PatternMatcher([
+ (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
+ (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
+ (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
+ (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
+ (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
+ (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
+ (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
+ (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
+ (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
+ (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
+ (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
+ src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
+ ])
+
+ @track_rewrites(named=True)
+ def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v}
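
The ring allreduce in the new multi.py splits each shard's flattened buffer into one chunk per device, accumulates chunk i onto the next device during the scatter-reduce phase, then circulates the reduced chunks during allgather and pads them back into place. A small worked example of just the chunk-size arithmetic (the 1000-element, 4-device numbers are made up for illustration):

import itertools

# same formulas as all_reduce above, run standalone for numel=1000 across 4 shards
numel, n_lbs = 1000, 4
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)   # -> 8
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs   # -> 31, 1
chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
print(chunk_sizes)   # [256, 248, 248, 248]  (sums to numel)
print(chunks)        # [(0, 256), (256, 504), (504, 752), (752, 1000)]

In practice this path is reached when a tensor sharded across devices with Tensor.shard is reduced: multi_pm rewrites the MULTI ops and reduce_multi calls all_reduce for fully-real shards.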