tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/jit.py CHANGED
@@ -1,14 +1,12 @@
- from __future__ import annotations
- from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
+ from typing import TypeVar, Generic, Callable, Union, cast, Optional, Any
  import functools, collections
  from tinygrad.tensor import Tensor
- from tinygrad.engine.lazy import LazyBuffer
- from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition
+ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
  from tinygrad.device import Buffer, Compiled, Device
  from tinygrad.dtype import DType
- from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
+ from tinygrad.ops import UOp, Variable, sym_infer, Ops
  from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
+ from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
  from tinygrad.engine.memory import _internal_memory_planner
  from tinygrad.nn.state import get_parameters
  from dataclasses import dataclass
@@ -16,21 +14,22 @@ from weakref import WeakKeyDictionary

  class GraphException(Exception): pass

- def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], max_batch_size=0) -> List[ExecItem]:
+ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int], max_batch_size=0) -> list[ExecItem]:
  # Split JIT cache into batches for faster graph execution.
  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
- graphed_jit_cache: List[ExecItem] = []
- current_batch: List[ExecItem] = []
+ graphed_jit_cache: list[ExecItem] = []
+ current_batch: list[ExecItem] = []
  current_device: Optional[Compiled] = None

  def flush_batch():
  nonlocal current_batch, current_device, max_batch_size
  try:
- if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
+ if current_device is None: raise GraphException("no device for graph")
+ if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
  graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
  # clear jit inputs to allow their memory to be freed/reused
  for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
- graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
+ graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Optional[Buffer]], input_rawbuffers)))
  max_batch_size *= 2
  if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
  except GraphException as e:
@@ -40,9 +39,9 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
  current_device = None

  for ji in jit_cache:
- if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
+ if isinstance(ji.prg, ViewOp): continue
  ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
- if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
+ if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
  elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
  ji_graph_dev = Device[ji.bufs[0].device]

@@ -61,24 +60,21 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
  if len(current_batch) > 0: flush_batch()
  return graphed_jit_cache

- def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
- input_replace: Dict[Tuple[int, int], int] = {}
+ def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer]) -> dict[tuple[int, int], int]:
+ input_replace: dict[tuple[int, int], int] = {}
  for j,ji in enumerate(jit_cache):
  for i,a in enumerate(ji.bufs):
  if a in input_rawbuffers:
  input_replace[(j,i)] = input_rawbuffers.index(a)
  return input_replace

- class GraphRunner(Runner): # pylint: disable=abstract-method
- def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
- self.jit_cache = jit_cache
- self.input_replace:Dict[Tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
- self.var_vals_replace:Dict[int, List[int]] = {}
- self.launch_dims_replace:Dict[int, Tuple[Optional[int], Optional[int]]] = {}
-
- op_estimate: sint = 0
- mem_estimate: sint = 0
- lds_estimate: sint = 0
+ class GraphRunner(Runner):
+ def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
+ self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
+ self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
+ self.var_vals_replace:dict[int, list[int]] = {}
+ self.launch_dims_replace:dict[int, tuple[Optional[int], Optional[int]]] = {}
+ self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}

  def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

@@ -87,33 +83,35 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
  [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
  def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None

+ estimates = Estimates()
  for j,ji in enumerate(jit_cache):
- op_estimate += ji.prg.op_estimate
- mem_estimate += ji.prg.mem_estimate
- lds_estimate += ji.prg.lds_estimate
+ estimates += ji.prg.estimates
  if isinstance(ji.prg, CompiledRunner):
  if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]

  global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
- if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+ if global_dim_idx is not None or local_dim_idx is not None:
+ self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+ assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
+ self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))

  # used in MultiGraphRunner. the ints are id() of _bufs
- self.w_dependency_map: Dict[int, Any] = {}
- self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
+ self.w_dependency_map: dict[int, Any] = {}
+ self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)

- super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
- ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
+ super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())

- def updated_vars(self, var_vals: Dict[Variable, int]):
+ def updated_vars(self, var_vals: dict[Variable, int]):
  vals = [var_vals[v] for v in self.vars]
  for j, vidxs in self.var_vals_replace.items():
  for i, v in enumerate(vidxs): yield j, i, vals[v]

- def updated_launch_dims(self, var_vals: Dict[Variable, int]):
+ def updated_launch_dims(self, var_vals: dict[Variable, int]):
  dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
- for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)
+ for j, (gl, lc) in self.launch_dims_replace.items():
+ yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])

- def _access_resources(self, rawbufs:List[Buffer], write:List[int], new_dependency:Any):
+ def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependency:Any):
  # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
  # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
  wait_nodes = []
@@ -128,43 +126,65 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
  return list({id(x):x for x in wait_nodes}.values())

  # a marker for your graph supporting multiple devices of the same type
- class MultiGraphRunner(GraphRunner): pass # pylint: disable=abstract-method
+ class MultiGraphRunner(GraphRunner): pass
+
+ def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
+ for ei in jit_cache:
+ if any(b in depends for b in ei.bufs):
+ if isinstance(ei.prg, CompiledRunner):
+ depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins)
+ if isinstance(ei.prg, (BufferCopy, BufferXfer)):
+ depends.add(cast(Buffer, ei.bufs[0]))

  ReturnType = TypeVar('ReturnType')
  @dataclass
  class CapturedJit(Generic[ReturnType]):
  ret: Any # includes the Tensors or any other returned object
- jit_cache: List[ExecItem]
- input_replace: Dict[Tuple[int, int], int]
- extra_view_inputs: List[Tuple[int, int, str, int, DType]]
- expected_names: List[Union[int, str]]
- expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]]
+ jit_cache: list[ExecItem]
+ input_replace: dict[tuple[int, int], int]
+ extra_view_inputs: list[tuple[int, int, str, int, DType]]
+ expected_names: list[Union[int, str]]
+ expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]

  def __reduce__(self):
+ # TODO: free_intermediates here?
  return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
  self.expected_names, self.expected_st_vars_dtype_device)

  def __post_init__(self):
- self._jit_cache: List[ExecItem] = self.jit_cache
- self._input_replace: Dict[Tuple[int, int], int] = self.input_replace
- self._graphed = False
+ self._jit_cache: list[ExecItem] = self.jit_cache
+ self._input_replace: dict[tuple[int, int], int] = self.input_replace
+ self._first_run = True
  self._clear_inputs()

  def _clear_inputs(self):
  for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None

+ def free_intermediates(self):
+ depends: set[Buffer|None] = set([None])
+ update_depends(depends, self.jit_cache)
+ for b in depends:
+ if b is not None: b.deallocate()
+ self.__post_init__() # reset the graph state
+
  # jit exec
- def __call__(self, input_buffers:List[Buffer], var_vals:Dict[Variable, int]) -> ReturnType:
+ def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) -> ReturnType:
  # assign inputs
  for idx, offset, device, size, dtype in self.extra_view_inputs:
  input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
  for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]

  # Condense the items into a graph executor.
- if JIT < 2 and not self._graphed:
- self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
- self._input_replace = get_input_replace(self._jit_cache, input_buffers)
- self._graphed = True
+ if self._first_run:
+ # allocate intermediates if freed
+ for ji in self.jit_cache:
+ for b in ji.bufs:
+ if b is not None: b.ensure_allocated()
+ # create graph if needed
+ if JIT < 2:
+ self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=getenv("JIT_BATCH_SIZE", 32))
+ self._input_replace = get_input_replace(self._jit_cache, input_buffers)
+ self._first_run = False

  if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
  for ei in self._jit_cache: ei.run(var_vals, jit=True)
@@ -172,13 +192,14 @@ class CapturedJit(Generic[ReturnType]):
  return self.ret

  def _prepare_jit_inputs(args, kwargs):
- input_tensors: List[Tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
+ input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
  names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
- if tensors: Tensor.realize(*tensors)
- lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for t in tensors])
- input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
+ if len(unrealized_tensors := [x for x in tensors if not x.lazydata.is_realized]): Tensor.realize(*unrealized_tensors)
+ # TODO: should we be unpacking multi here?
+ lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
+ input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
- st_varval_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
+ st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
  var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
  st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
  return input_buffers, var_vals, names, st_vars_dtype_device
@@ -214,9 +235,9 @@ class TinyJit(Generic[ReturnType]):

  # keep legacy code working
  @property
- def jit_cache(self) -> List[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
+ def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
  @property
- def input_replace(self) -> Dict[Tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}
+ def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}

  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods

@@ -232,7 +253,7 @@ class TinyJit(Generic[ReturnType]):
  # jit capture
  assert self.fxn is not None
  if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
- self._jit_cache: List[ExecItem] = []
+ self._jit_cache: list[ExecItem] = []
  self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
  # TODO: should we always disable the memory planner here? it must be off for prune
  with Context(BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
@@ -249,7 +270,7 @@ class TinyJit(Generic[ReturnType]):

  # track inputs that are views of buffers
  # TODO: eventually expected_buffers should live in ExecItem
- extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
+ extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
  for item in jit_cache:
  for b in item.bufs:
  if b is not None and b._base is not None and b._base in input_buffers:
@@ -259,10 +280,7 @@ class TinyJit(Generic[ReturnType]):
  # prune independent kernels (optional)
  if self.prune:
  depends = set(input_buffers)
- for ei in jit_cache:
- if any(b in depends for b in ei.bufs):
- if isinstance(ei.prg, CompiledRunner):
- depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
+ update_depends(depends, jit_cache)
  pruned, onetime = partition(jit_cache,
  lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
  if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
@@ -275,7 +293,7 @@ class TinyJit(Generic[ReturnType]):
  # memory planning (optional)
  # Exclude buffers involved in transfer ops to preserve parallelism.
  noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
- assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
+ assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
  jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]

  input_replace = get_input_replace(jit_cache, input_buffers)
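
Note on the jit.py changes above: CapturedJit now tracks a _first_run flag, re-allocates freed buffers on the next call, and gains free_intermediates() built on the new update_depends() helper. A minimal usage sketch, illustrative only and not part of the package; it assumes tinygrad 0.10.2 with any working backend, and the toy step function is hypothetical:

# illustrative sketch, not from the diff
from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x.relu() * 2).contiguous().realize()

# capture typically happens on the second call; later calls may be batched into a device graph
for _ in range(3): step(Tensor.rand(64, 64))

if step.captured is not None:
  step.captured.free_intermediates()   # new in 0.10.2: deallocate intermediate buffers
  step(Tensor.rand(64, 64))            # first run after freeing re-allocates via ensure_allocated()
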
tinygrad/engine/memory.py CHANGED
@@ -1,4 +1,3 @@
- from typing import List, Union, Tuple, Dict
  from collections import defaultdict
  from tinygrad.engine.schedule import ScheduleItem
  from tinygrad.device import Device, Buffer
@@ -7,7 +6,7 @@ from tinygrad.ops import Ops

  # **************** memory planning ****************

- def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
+ def _internal_memory_planner(buffers:list[list[Buffer]|tuple[Buffer, ...]], noopt_buffers=None, debug_prefix="") -> dict[Buffer, Buffer]:
  if NO_MEMORY_PLANNER: return {}
  first_appearance, last_appearance = {}, {}
  for i,u in enumerate(buffers):
@@ -18,7 +17,7 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]

  # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
  # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
- free_segs: Dict[Tuple, List[Tuple[int, int, Buffer]]] = defaultdict(list) # Dict[buffer key, Tuple[start, end, buffer to reuse on the seg]]
+ free_segs: dict[tuple, list[tuple[int, int, Buffer]]] = defaultdict(list) # dict[buffer key, tuple[start, end, buffer to reuse on the seg]]
  def find_replace_buffer(buf, st, en):
  key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())

@@ -44,8 +43,8 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]
  f"{len(ak)} -> {len(av)} bufs")
  return assigned

- def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
+ def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
  # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
  assigned = _internal_memory_planner([si.bufs for si in schedule],
  noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
- return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule]
+ return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
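
The comments in _internal_memory_planner above describe the approach: record each buffer's first and last appearance over the schedule, then let later buffers reuse storage whose lifetime has already ended. A toy, self-contained sketch of that interval-reuse idea (not tinygrad's implementation; it omits the size-descending sort and the per-device/dtype/size bucketing the real planner uses):

# toy sketch of lifetime-based buffer reuse, assuming made-up buffer names
def plan(lifetimes: dict[str, tuple[int, int]]) -> dict[str, str]:
  # lifetimes: buffer name -> (first appearance, last appearance) in the schedule
  assigned: dict[str, str] = {}
  free: list[tuple[int, str]] = []  # (schedule index after which the storage is free, storage name)
  for name, (st, en) in sorted(lifetimes.items(), key=lambda kv: kv[1][0]):
    idx = next((i for i, (free_at, _) in enumerate(free) if free_at < st), None)
    assigned[name] = free.pop(idx)[1] if idx is not None else name  # reuse a dead buffer or allocate fresh
    free.append((en, assigned[name]))
  return assigned

print(plan({"a": (0, 1), "b": (1, 3), "c": (2, 4)}))  # {'a': 'a', 'b': 'b', 'c': 'a'}: c reuses a's storage
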
tinygrad/engine/multi.py ADDED
@@ -0,0 +1,161 @@
+ import functools, itertools, operator
+ from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
+ from tinygrad.ops import Ops, UOp, sint
+
+ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
+ assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
+ assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
+ n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
+ # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+ # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+ use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+ if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
+ if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
+
+ factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
+ base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
+ chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
+ chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
+ chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
+
+ # scatter-reduce
+ for step in range(n_lbs-1):
+ for i in range(len(chunks)):
+ src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+ chunked[dest][i] = chunked[dest][i].alu(bop, chunked[src][i].copy_to_device(chunked[dest][i].device))
+
+ # allgather
+ for step in range(n_lbs-1):
+ for i in range(len(chunks)):
+ src, dest = (i+step-1)%n_lbs, (i+step)%n_lbs
+ chunked[dest][i] = chunked[src][i].copy_to_device(chunked[dest][i].device)
+
+ # assemble chunks back
+ pads = [((s,numel-e),) for s,e in chunks]
+ return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]
+
+ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> list[UOp]:
+ if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
+ return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
+
+ # ***** multi functions *****
+
+ from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
+
+ def alu_multi(root:UOp):
+ msrcs = root.src
+ assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}"
+ assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+
+ axis = root.axis
+ bounds = dedup([x.bounds for x in root.src if x.axis == axis])[-1] if axis is not None else None
+ srcs:list[list[UOp]] = []
+ not_all_real = not all(all(mlb.real) for mlb in msrcs)
+ new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
+ for mlb in msrcs:
+ if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src))
+ else:
+ assert axis is not None and bounds is not None
+ if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds))
+ else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds))
+ new_lbs = [lsrcs[0].alu(root.op, *lsrcs[1:]) for lsrcs in zip(*srcs)]
+ new_lbs = [x if r else x.const_like(0) for r,x in zip(new_real, new_lbs)] # TODO: is this needed?
+ return UOp.multi(*new_lbs, axis=axis, real=new_real)
+
+ def reduce_multi(root:UOp, multi:UOp):
+ op, axis = root.arg
+ if multi.axis is not None and multi.axis in axis:
+ # all-reduce on sharded axes
+ reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
+ # if all partitions are real, do all_reduce
+ if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=root.axis)
+ # only one partition is real, keep it
+ return UOp.multi(*reduced_parts, axis=root.axis, real=multi.real)
+ # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+ return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=root.axis, real=multi.real)
+
+ def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
+ return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
+
+ def reshape_multi(root:UOp, multi:UOp):
+ arg = root.arg
+ if (new_axis:=root.axis) is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=new_axis, real=multi.real)
+ assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
+ assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
+ f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
+ lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
+ return UOp.multi(*lbs, axis=new_axis, real=multi.real)
+
+ def expand_multi(root:UOp, multi:UOp):
+ # NOTE: this assert isn't needed, sharded axis can have dim 1
+ assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
+ return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real)
+
+ def pad_multi(root:UOp, multi:UOp):
+ assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}"
+ # pad on shard axis -> fill others with zeros and set real to all True
+ if multi.axis is not None and root.arg[multi.axis] != (0,0):
+ # pad back to whole axis, remove real mask
+ assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time"
+ dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)]
+ assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
+ return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis)
+ return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
+
+ def permute_multi(root:UOp, multi:UOp):
+ # all permutes supported!
+ return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.axis, real=multi.real)
+
+ def shrink_multi(root:UOp, multi:UOp):
+ assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
+ f"shrinking not supported for {root.arg=}"
+ if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
+ assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
+ "cannot shrink sharded and non-sharded axis at the same time"
+ # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
+ idx = multi.bounds.index(root.arg[multi.axis])
+ # zero out other lbs to not create lb reference
+ return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)],
+ axis=multi.axis, real=tuple(i==idx for i in range(len(multi.src))))
+ return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
+ axis=multi.axis, real=multi.real)
+
+ def flip_multi(root:UOp, multi:UOp):
+ assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
+ return UOp.multi(*[x.flip(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
+
+ def copy_multi(multi:UOp, device:UOp):
+ # if we already have a copy on the device, return that
+ if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg))
+ # copy lbs to device, pad to final shape, and sum
+ llbs:list[UOp] = []
+ for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds):
+ if not real: continue
+ pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape)))
+ llbs.append(lb.copy_to_device(device.arg).pad(pad_arg))
+ return functools.reduce(operator.add, llbs)
+
+ def assign_multi(dest:UOp, src:UOp):
+ assert dest.axis == src.axis and dest.real == src.real, f"axis/real must match in assign {dest.axis} != {src.axis} or {dest.real} != {src.real}"
+ return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real)
+
+ def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real)
+
+ # NOTE: this is the same pattern as Ops.UNROLL
+ multi_pm = PatternMatcher([
+ (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
+ (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
+ (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
+ (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
+ (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
+ (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
+ (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
+ (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
+ (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
+ (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
+ (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
+ src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
+ ])
+
+ @track_rewrites(named=True)
+ def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v}
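
The ring path of all_reduce above splits each shard's flattened buffer into one chunk per device before the scatter-reduce and allgather passes. The chunk-boundary arithmetic, lifted verbatim into a standalone sketch so it can be run without tinygrad (Python 3.10+ for itertools.pairwise):

# standalone sketch of the chunk-boundary arithmetic used by the ring all_reduce
import itertools

def ring_chunks(numel: int, n_devices: int) -> list[tuple[int, int]]:
  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)          # keep chunks aligned to a small factor
  base, left = (numel // factor) // n_devices, (numel // factor) % n_devices  # distribute the remainder over the first shards
  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_devices - left)
  return list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))

print(ring_chunks(1_000_000, 3))  # [(0, 333344), (333344, 666688), (666688, 1000000)]: contiguous ranges covering the flat buffer
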