tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/schedule/multi.py (new file)
@@ -0,0 +1,231 @@
+ from typing import cast
+ import functools, itertools, operator
+ from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv, unwrap
+ from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, resolve
+ from tinygrad.device import Device
+
+ # *** allreduce implementation ***
+ def handle_allreduce_multirank(buf:UOp, red:UOp) -> UOp|None:
+   if not isinstance(buf.device, tuple): return None
+
+   # Group buffers
+   groups: dict[int|None, list[UOp]] = {}
+   for i,dev in enumerate(buf.device):
+     groups.setdefault(Device[dev].group_id, []).append(buf.mselect(i))
+
+   # Put reduce leader of each group first
+   reduce_leaders = set(getenv("REDUCE_LEADERS", "").split(","))
+   groups = {gid: sorted(bufs, key=lambda x: (x.device not in reduce_leaders, x.device)) for gid,bufs in groups.items()}
+
+   # Skip if only one group or if every group has only one buffer
+   if len(groups) <= 1 or not any(len(g) > 1 for g in groups.values()): return None
+
+   # Reduce inside each group
+   inner = [UOp(Ops.MSTACK, buf.dtype, tuple(bufs)).allreduce(red.arg, (cast(str, bufs[0].device),)).mselect(0) for bufs in groups.values()]
+
+   # Allreduce across groups
+   outer = UOp(Ops.MSTACK, buf.dtype, tuple(inner)).allreduce(red.arg, tuple(buf.device for buf in inner))
+
+   # Broadcast back to all devices in the group
+   gid2bid = {Device[device].group_id: i for i,device in enumerate(outer.device)}
+   return outer.mselect(gid2bid[Device[red.device].group_id]).copy_to_device(red.device) if not isinstance(red.device, tuple) else \
+     UOp(Ops.MSTACK, buf.dtype, tuple(outer.mselect(gid2bid[Device[device].group_id]).copy_to_device(device) for device in red.device))
+
+ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
+   if not isinstance(buf.device, tuple): return None
+   assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
+   n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
+   # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
+   # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
+   use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+   if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {buf.dtype}")
+
+   # contiguous before we copy it
+   buf = buf.contiguous()
+
+   # copy to all devices. if you shrink later, that'll be handled
+   if not use_ring: return functools.reduce(lambda x,y: x.alu(red.arg, y),
+     [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(len(buf.device))])
+
+   # new ring reduce
+   factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
+   base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
+   chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
+   chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
+
+   # extract chunks and scatter-reduce
+   reduced_chunks = []
+   for i,(s,e) in enumerate(chunks):
+     chunk = buf.reshape((numel,)).shrink(((s,e),))
+     reduced_chunk = chunk
+     for step in range(n_lbs-1):
+       src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+       # copy the chunk from the src device to the dest (operating device), and select the chunk on the dest device
+       reduced_chunk = reduced_chunk.copy_to_device(buf.device[dest], src if isinstance(reduced_chunk.device, tuple) else None) \
+         .alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
+     reduced_chunks.append(reduced_chunk)
+
+   # allgather
+   copied_chunks = []
+   for i,c in enumerate(reduced_chunks):
+     this_chunk = [None] * len(buf.device)
+     this_chunk[(i+len(buf.device)-1)%n_lbs] = c
+     for step in range(n_lbs-1):
+       dest = (i+step)%n_lbs
+       this_chunk[dest] = c = c.copy_to_device(buf.device[dest])
+     copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
+
+   # reassemble
+   pads = [((s,numel-e),) for s,e in chunks]
+   return functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads, copied_chunks)]).reshape(shape)
+
+ # ***** multi rewrite MSELECT/MSTACK *****
+
+ def _replace_dnum(st, val):
+   # replace dnum in ShapeTracker with literal const for this mselect
+   if (dnums:=[x for x in st.vars() if x.op is Ops.DEFINE_VAR and x.arg[0] == '_device_num']):
+     assert len(dnums) == 1, f"view must have exactly 0 or 1 dnum, got {dnums}"
+     st = st.substitute({dnums[0]:dnums[0].const_like(val)})
+   return st
+
+ def mstack_reorder_view(ms:UOp):
+   args = [x.arg for x in ms.src]
+   if not all_same(args) or len([x for x in args[0].vars() if x.arg[0] == '_device_num']) != 0: return None
+   return UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).view(args[0])
+
+ def mstack_early_shrink(view:UOp, ms:UOp):
+   if resolve(prod(view.shape) >= prod(ms.shape)) or _replace_dnum(view.st, 0) == view.st: return None
+   ret = []
+   for i, x in enumerate(ms.src):
+     new_view = _replace_dnum(view.st, i)
+     if x.op is Ops.COPY:
+       # if src device doesn't have a renderer, we have to view after the copy
+       # TODO: a way to understand this
+       if x.src[0].device in {"DISK", "NPY"}:
+         ret.append(x.view(new_view))
+       else:
+         ret.append(x.src[0].view(new_view).copy_to_device(x.device))
+     else:
+       ret.append(x.view(new_view).contiguous())
+   return ms.replace(src=tuple(ret))
+
+ replace_allreduce = PatternMatcher([
+   (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce_multirank),
+   (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
+   # BROADCAST: explicitly expand broadcast copies and combine with MSTACK
+   (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
+     UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),
+   # COPY_TO_ONE: if copying from multidevice to one, MSELECT the first (TODO: a little from each?)
+   (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
+     x.mselect(0).copy_to_device(c.device) if isinstance(c.device, str) and isinstance(x.device, tuple) else None),
+   # MSELECT on MSTACK is replaced with nothing
+   (UPat(Ops.MSELECT, src=(UPat(Ops.MSTACK, name="mstack"),), name="ms"), lambda mstack, ms: mstack.src[ms.arg]),
+   # MSELECT must select a base, if there are views apply them after selecting the base
+   (UPat(Ops.MSELECT, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"),), name="ms"), lambda ms, view, base:
+     base.mselect(ms.arg).view(_replace_dnum(unwrap(view.st), ms.arg))),
+   # move view through MSTACK
+   (UPat(Ops.MSTACK, src=UPat(Ops.VIEW), name="ms"), mstack_reorder_view),
+   # move shrink before MSTACK
+   (UPat(Ops.VIEW, src=(UPat(Ops.MSTACK, name="ms"),), name="view"), mstack_early_shrink),
+ ])
+
+ # ***** multi functions *****
+
+ def alu_multi(root:UOp):
+   msrcs = root.src
+   assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
+   axis = root.axis
+   assert axis is not None
+
+   srcs = []
+   for mlb in msrcs:
+     if mlb.axis == axis:
+       # same axis, just copy through
+       assert mlb.op is Ops.MULTI
+       srcs.append(mlb.src[0])
+     elif mlb.axis is None:
+       # no axis, shard it
+       assert mlb.op is not Ops.MULTI
+       srcs.append(mlb._shard(axis))
+     else:
+       # axis mismatch, unshard it, send it to all devices, and shard it correctly
+       assert mlb.op is Ops.MULTI
+       srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
+   return srcs[0].alu(root.op, *srcs[1:]).multi(axis)
+
+ def reduce_multi(root:UOp, multi:UOp):
+   op, axis = root.arg
+   if multi.axis is not None and multi.axis in axis:
+     # all-reduce on sharded axes
+     return multi.src[0].r(op, axis).allreduce(op, multi.device)
+   # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
+   return multi.src[0].r(op, axis).multi(axis=multi.axis)
+
+ def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
+   return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
+
+ def reshape_multi(root:UOp, multi:UOp):
+   arg = root.arg
+   if (new_axis:=root.axis) is None: return multi.src[0].reshape(arg).multi(new_axis)
+   assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
+   assert prod(multi.src[0].shape[multi.axis:])%prod(arg[new_axis+1:]) == 0, f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
+   new_shape_axis = prod(multi.src[0].shape[multi.axis:]) // prod(arg[new_axis+1:])
+   return multi.src[0].reshape(tuple(s if a!=new_axis else new_shape_axis for a,s in enumerate(arg))).multi(new_axis)
+
+ def expand_multi(root:UOp, multi:UOp):
+   # NOTE: this assert isn't needed, sharded axis can have dim 1
+   assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
+   return multi.src[0].expand(_shape_to_single_shard(multi.axis, root.arg, multi.src[0])).multi(multi.axis)
+
+ def pad_multi(root:UOp, multi:UOp):
+   assert multi.axis is None or root.arg[multi.axis] == (0,0), f"padding not supported for {root.arg=}"
+   return multi.src[0].pad(root.arg).multi(multi.axis)
+
+ def permute_multi(root:UOp, multi:UOp):
+   # all permutes supported!
+   return multi.src[0].permute(root.arg).multi(root.axis)
+
+ def shrink_multi(root:UOp, multi:UOp):
+   assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
+     f"shrinking not supported for {root.arg=}"
+   if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
+     assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
+       "cannot shrink sharded and non-sharded axis at the same time"
+     # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
+     # we just copy it to all the devices, no real. this will be optimized out later
+     return multi.src[0].copy_to_device(multi.device, arg=multi.bounds.index(root.arg[multi.axis]))
+   return multi.src[0].shrink(tuple((0, multi.src[0].shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))).multi(multi.axis)
+
+ def flip_multi(root:UOp, multi:UOp):
+   assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
+   return multi.src[0].flip(root.arg).multi(multi.axis)
+
+ # from multiple devices -> one
+ def copy_multi(multi:UOp, device:UOp):
+   assert multi.axis is not None, "all multi ops have axis"
+   return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)
+
+ def assign_multi(dest:UOp, src:UOp):
+   if dest.axis != src.axis: raise RuntimeError(f"axis must match in assign {dest.axis} != {src.axis}")
+   return dest.src[0].assign(src.src[0]).multi(src.axis)
+
+ def passthrough_multi(root:UOp, multi:UOp):
+   return root.replace(src=(multi.src[0],)).multi(multi.axis)
+
+ # NOTE: this is the same pattern as Ops.UNROLL
+ multi_pm = PatternMatcher([
+   (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
+   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
+   (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
+   (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
+   (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
+   (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
+   (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
+   (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
+   (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
+   (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
+   (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
+     lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),
+   (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE),
+     src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
+ ])+replace_allreduce
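
The chunking arithmetic in handle_allreduce above is easiest to follow with concrete numbers. Below is a minimal standalone sketch (plain Python; the buffer size and device count are made-up values, not taken from the diff) of how factor, base/left and the chunk boundaries work out:

import itertools

# illustrative only: mirrors the chunk-size math from handle_allreduce above
numel, n_lbs = 1000, 4                                               # hypothetical: 1000 elements across 4 devices
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)   # largest listed factor dividing numel -> 8
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs   # 31 blocks of 8 per device, 1 block left over
chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)  # [256, 248, 248, 248]
chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
print(chunks)  # [(0, 256), (256, 504), (504, 752), (752, 1000)]

Each device then owns one of these slices during the scatter-reduce loop, and the allgather loop passes the reduced slice around the ring before the chunks are padded back out and summed into the full shape.
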
tinygrad/shape/shapetracker.py
@@ -2,58 +2,55 @@
  from __future__ import annotations
  from dataclasses import dataclass
  import functools
- from typing import Optional, Callable
+ from typing import Callable
  from tinygrad.helpers import merge_dicts, getenv
- from tinygrad.shape.view import View, strides_for_shape, unravel
+ from tinygrad.shape.view import View, unravel
  from tinygrad.dtype import dtypes
- from tinygrad.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context
- from tinygrad.codegen.symbolic import sym, split_uop, symbolic_flat, uop_given_valid, simplify_valid
-
- def overflow(u: UOp): return u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int)
+ from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context, PatternMatcher, UPat, GroupOp
+ from tinygrad.uop.symbolic import split_uop, symbolic_flat, uop_given_valid, simplify_valid

  # If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
  # or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
- def upcast(u: UOp):
-   srcs = tuple(upcast(_src) for _src in u.src)
-   if u.dtype.scalar() is dtypes.int:
-     dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
-     upcasted = u.replace(dtype=dtype, src=tuple([_src.cast(dtype) for _src in srcs]))
-     if overflow(u): return upcasted
-     # Check the original src, new srcs has Ops.CAST whose vmin, vmax change the real bounds
-     # Cast back is required because if the node is in range, siblings would never be upcasted
-     if any((overflow(src) for src in u.src)): return upcasted.cast(u.dtype)
-   return u.replace(src=tuple(srcs))
-
- # pooling op may overflow before folding causing unnecessary upcast
- def folded_upcast(u: UOp):
-   with Context(TRACK_MATCH_STATS=0):
-     return upcast(graph_rewrite(u, sym, {}))
-
- @functools.lru_cache(None)
- def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
+ def handle_upcast(u: UOp) -> UOp|None:
+   dtype = dtypes.int64.vec(u.dtype.count) if u.dtype.count > 1 else dtypes.int64
+   # check for overflow, upcast this to int64
+   if u.vmax > dtypes.max(dtypes.int) or u.vmin < dtypes.min(dtypes.int):
+     return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src]))
+   # if any inputs are int64 and this *doesn't* overflow, cast back to int
+   if any(x.dtype == dtypes.int64 for x in u.src):
+     return u.replace(dtype=dtype, src=tuple([x.cast(dtype) for x in u.src])).cast(u.dtype)
+   return None
+ pm_upcast = PatternMatcher([(UPat(GroupOp.ALU, dtype=dtypes.int, name="u"), handle_upcast),])
+
+ @functools.cache
+ def views_to_indexed_uops(views: tuple[View, ...], _idxs:tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:
    idx, valid = views[-1].to_indexed_uops(_idxs)
    for view in reversed(views[0:-1]):
      view = view.minify()
      idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
-   return idx, valid
-
- @functools.lru_cache(None)
- def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
+   with Context(TRACK_MATCH_STATS=0):
+     # symbolic
+     idx, valid = graph_rewrite(UOp.sink(idx, valid), symbolic_flat, name="indexing sym @ 1").src
+     # simplify
+     if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
+     if (newidx:=uop_given_valid(valid, idx)) is not None: idx = newidx
+     # symbolic again, upcast if needed
+     return graph_rewrite(UOp.sink(idx, valid), symbolic_flat+pm_upcast, name="indexing sym @ 2").src
+
+ @functools.cache
+ def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[sint|None, ...]:
    # NOTE: if a stride is not always valid, it will be None
    if len(views) == 1 and views[-1].mask is None: return views[-1].strides
-   ret: list[Optional[sint]] = [None] * len(views[-1].shape)
-   idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views))
-   # TODO: always apply these in to_indexed_uops?
-   if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
-   if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
+   ret: list[sint|None] = [None] * len(views[-1].shape)
+   idx, valid = views_to_indexed_uops(views)
    for c in split_uop(idx, Ops.ADD):
      if c.op is Ops.RANGE: ret[c.arg] = 1
      if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
      if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
-   used_ranges = [x.arg for x in idx.toposort if x.op is Ops.RANGE]
+   used_ranges = [x.arg for x in idx.toposort() if x.op is Ops.RANGE]
    ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
    if not ignore_valid:
-     for masked_axis in [x.arg for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
+     for masked_axis in [x.arg for x in valid.toposort() if x.op is Ops.RANGE]: ret[masked_axis] = None
    return tuple(ret)

  @dataclass(frozen=True, order=True)
@@ -65,7 +62,7 @@ class ShapeTracker:
      for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
      return ret

-   def invert(self, out_shape:tuple[sint, ...]) -> Optional[ShapeTracker]:
+   def invert(self, out_shape:tuple[sint, ...]) -> ShapeTracker|None:
      inverted_views:list[View] = []
      for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
        if (inverted:= v.invert(s)) is None: return None
@@ -73,14 +70,11 @@ class ShapeTracker:
      return ShapeTracker(tuple(inverted_views)).reshape(out_shape)

    @staticmethod
-   def from_shape(shape:tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
+   def from_shape(shape:tuple[sint, ...], strides:tuple[sint, ...]|None=None) -> ShapeTracker: return ShapeTracker((View.create(shape, strides),))

    @property
    def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous

-   @property
-   def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
-
    @property
    def shape(self) -> tuple[sint, ...]: return self.views[-1].shape

@@ -89,10 +83,8 @@ class ShapeTracker:

    def reduce(self, axis:tuple[int, ...]) -> tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))

-   def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
-   def to_indexed_uops(self, _idxs:Optional[list[UOp]|tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
-     idx, valid = views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
-     return folded_upcast(idx), folded_upcast(valid)
+   def to_indexed_uops(self, _idxs:list[UOp]|tuple[UOp, ...]|None=None) -> tuple[UOp, UOp]:
+     return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)

    # upper bound on buffer size required to fit this shapetracker
    def real_size(self) -> int:
@@ -111,14 +103,16 @@ class ShapeTracker:
      unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
      if all(len(x) == 0 for x in var_vals): return self, {}
      return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
+   def substitute(self, dvars:dict[UOp, UOp]): return ShapeTracker(tuple(x.substitute(dvars) for x in self.views))

-   def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid)
+   def real_strides(self, ignore_valid=False) -> tuple[sint|None, ...]:
+     with Context(TRACK_MATCH_STATS=0): return views_to_real_strides(self.views, ignore_valid)
    def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

    def axis_is_masked(self, axis:int) -> bool:
      with Context(TRACK_MATCH_STATS=0):
        _, valid = self.to_indexed_uops()
-       return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
+       return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort() if x.op is Ops.RANGE]

    def simplify(self) -> ShapeTracker:
      if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
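
For context on the pm_upcast rule introduced above: it fires only on int-typed ALU uops whose value bounds (vmin/vmax) cannot be represented in signed 32-bit integers, so large index expressions are computed in int64 while everything else stays int32. A minimal sketch of that bounds test on plain integers (the index sizes below are made up for illustration, not taken from the diff):

INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def needs_int64(vmin: int, vmax: int) -> bool:
  # same idea as the overflow check in handle_upcast: upcast only when the bounds do not fit in int32
  return vmax > INT32_MAX or vmin < INT32_MIN

print(needs_int64(0, 3_000_000 * 1024))  # True  -> the expression would be rewritten to int64
print(needs_int64(0, 1024))              # False -> it stays int32
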