tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/ops.py DELETED
@@ -1,1003 +0,0 @@
1
- from __future__ import annotations
2
- from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, get_args
3
- import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
4
- from enum import auto, IntEnum, Enum
5
- from dataclasses import dataclass, field
6
- from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
7
- from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
8
- from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
9
- if TYPE_CHECKING:
10
- from tinygrad.shape.shapetracker import ShapeTracker
11
- from tinygrad.device import Buffer
12
-
13
- # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
14
- class FastEnum(IntEnum):
15
- def __str__(self): return Enum.__str__(self)
16
- @staticmethod
17
- def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
18
-
19
- class SimpleMathTrait:
20
- # required to implement
21
- def alu(self:T, arg:Ops, *src) -> T: raise NotImplementedError
22
- def const_like(self:T, b:ConstLike) -> T: raise NotImplementedError
23
-
24
- # great functions you get!
25
- def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
26
- def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
27
- def logical_not(self): return self.ne(True)
28
- def neg(self):
29
- if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
30
- return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
31
- def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
32
- def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
33
- def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse)
34
- def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
35
- def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
36
- def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
37
- def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
38
- def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
39
- def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
40
-
41
- def __neg__(self): return self.neg()
42
-
43
- def __add__(self, x): return self.add(x)
44
- def __sub__(self, x): return self.sub(x)
45
- def __mul__(self, x): return self.mul(x)
46
- def __truediv__(self, x): return self.div(x)
47
- def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
48
- def __mod__(self, x): return self.mod(x)
49
- def __and__(self, x): return self.bitwise_and(x)
50
- def __or__(self, x): return self.bitwise_or(x)
51
- def __xor__(self, x): return self.xor(x)
52
-
53
- def __radd__(self, x): return self.add(x, True)
54
- def __rsub__(self, x): return self.sub(x, True)
55
- def __rmul__(self, x): return self.mul(x, True)
56
- def __rtruediv__(self, x): return self.div(x, True)
57
- def __rfloordiv__(self, x): return self.idiv(x, True)
58
- def __rand__(self, x): return self.bitwise_and(x, True)
59
- def __ror__(self, x): return self.bitwise_or(x, True)
60
- def __rxor__(self, x): return self.xor(x, True)
61
- def __rmod__(self, x): return self.mod(x, True)
62
-
63
- def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
64
- def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
65
- def __ge__(self, x): return (self < x).logical_not()
66
- def __le__(self, x): return (self > x).logical_not()
67
-
68
- def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
69
- def eq(self, x): return self.ne(x).logical_not()
70
- def __ne__(self, x): return self.ne(x)
71
- # NOTE: __eq__ isn't overridden, and means the same thing as is by default
72
-
73
- class MathTrait(SimpleMathTrait):
74
- # TODO: move to Tensor when new backward is done
75
- def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
76
- def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
77
- def __lshift__(self, x): return self.lshift(x)
78
- def __rshift__(self, x): return self.rshift(x)
79
- def __rlshift__(self, x): return self.lshift(x, True)
80
- def __rrshift__(self, x): return self.rshift(x, True)
81
-
82
- def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
83
- def minimum(self, x): return -(-self).maximum(-x)
84
- def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
85
- def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
86
- def reciprocal(self): return self.alu(Ops.RECIP)
87
- def sqrt(self): return self.alu(Ops.SQRT)
88
- def sin(self): return self.alu(Ops.SIN)
89
- def log2(self): return self.alu(Ops.LOG2)
90
- def exp2(self): return self.alu(Ops.EXP2)
91
- def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
92
-
93
- # the order of these Ops controls the order of the toposort
94
- class Ops(FastEnum):
95
- # uops that aren't rendered
96
- NAME = auto(); SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702
97
-
98
- # TODO: empty continues to exist because of tensor
99
- EMPTY = auto()
100
-
101
- # MetaOps
102
- COPY = auto(); BUFFER_VIEW = auto() # noqa: E702
103
-
104
- # blocks in linearizer
105
- BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702
106
-
107
- # movement ops!
108
- RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
109
-
110
- # misc ops
111
- UNROLL = auto(); CONTRACT = auto() # noqa: E702
112
- VIEW = auto(); DEFINE_GLOBAL = auto(); BUFFER = auto() # noqa: E702
113
- DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
114
- VALID = auto(); SPECIAL = auto(); NOOP = auto() # noqa: E702
115
-
116
- # reduce
117
- REDUCE_AXIS = auto()
118
-
119
- # helper ops
120
- GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702
121
-
122
- # UnaryOps
123
- CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
124
-
125
- # load/store before math
126
- LOAD = auto(); STORE = auto() # noqa: E702
127
-
128
- # early INDEX
129
- INDEX = auto()
130
-
131
- # math ops
132
- WMMA = auto()
133
-
134
- # BinaryOps
135
- ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
136
- SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
137
-
138
- # TernaryOps
139
- WHERE = auto(); MULACC = auto() # noqa: E702
140
-
141
- # assignment ops
142
- ASSIGN = auto()
143
- BIND = auto()
144
-
145
- # control flow ops
146
- BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
147
-
148
- # consts last!
149
- VCONST = auto(); CONST = auto() # noqa: E702
150
-
151
- # device
152
- DEVICE = auto()
153
- MULTI = auto()
154
- CUSTOM = auto()
155
-
156
- class GroupOp:
157
- Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
158
- Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
159
- Ops.SUB, Ops.FDIV, Ops.POW}
160
- Ternary = {Ops.WHERE, Ops.MULACC}
161
- ALU = set.union(Unary, Binary, Ternary)
162
-
163
- Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
164
- Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
165
-
166
- Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
167
- Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
168
-
169
- # BinaryOps that can be flipped
170
- Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.XOR, Ops.AND, Ops.OR}
171
-
172
- # BinaryOps where f(f(a,b),c) = f(a,f(b,c))
173
- Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
174
-
175
- # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
176
- Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
177
-
178
- # do not preserve f(0) = 0
179
- UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
180
-
181
- All = set(Ops)
182
-
183
- # some BUFFER ops can be processed with only a view
184
- view_supported_devices = {"LLVM", "CPU", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
185
-
186
- # https://en.wikipedia.org/wiki/Identity_element
187
- def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
188
-
189
- def can_pad(u:UOp, edges:dict[UOp, None], cache:dict[UOp, None]) -> bool:
190
- if u.op in GroupOp.UnsafePad: return False
191
- if u in edges or u in cache: return True
192
- cache[u] = None
193
- return all(can_pad(x.base, edges, cache) for x in u.src)
194
-
195
- # With True as the default, this matches the old symbolic behavior
196
- def resolve(x:UOp|bool, default:bool=True):
197
- if isinstance(x, bool): return x
198
- assert x.dtype == dtypes.bool, "UOp in resolve must be bool"
199
- # NOTE: generating the text for the exception is expensive, so we do this
200
- return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
201
-
202
- # smax/smin are replacements for max/min that preserve symbolic
203
- def _suop(lst, uop_fxn, python_fxn):
204
- uops, nums = partition(lst, lambda x: isinstance(x, UOp))
205
- return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
206
- def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
207
- def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)
208
-
209
- def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
210
- def sym_infer(uop: Union[UOp, int], var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
211
-
212
- # used for UOp and UPat
213
- def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
214
- def dfs(x:Any, cache:dict):
215
- for s in srcfn(x) or []:
216
- cache.setdefault(s, [len(cache), 0, False])[1] += 1
217
- if cache[s][1] == 1: dfs(s, cache)
218
- if cache is None: dfs(x, cache:={})
219
- if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
220
- cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
221
- return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
222
-
223
- class UOpMetaClass(type):
224
- ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
225
- def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer:Buffer|None=None):
226
- if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
227
- UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
228
- for s in src: s.children.add(ref)
229
- # NOTE: this will soon be set by Tensor once we remove function.py
230
- if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata
231
- # NOTE: this value is set by pickle when pickling a realized tensor
232
- if _buffer is not None:
233
- assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
234
- buffers[created] = _buffer
235
- return created
236
-
237
- # some uops map to other stuff
238
- buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
239
- all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
240
-
241
- # NOTE: this should be frozen, but frozen is slower
242
- @dataclass(eq=False, slots=True)
243
- class UOp(MathTrait, metaclass=UOpMetaClass):
244
- op:Ops
245
- dtype:DType = dtypes.void
246
- src:tuple[UOp, ...] = tuple()
247
- arg:Any = None
248
- children:set[weakref.ref[UOp]] = field(default_factory=set)
249
- def __del__(self):
250
- if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
251
- if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
252
- for s in self.src: s.children.discard(ref)
253
- del UOpMetaClass.ucache[k]
254
- def __reduce__(self):
255
- args = [self.op, self.dtype, self.src, self.arg]
256
- if self.op is Ops.BUFFER and self.realized is not None and PICKLE_BUFFERS: args.append(self.realized)
257
- return UOp, tuple(args)
258
- def replace(self, **kwargs) -> UOp:
259
- new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg))
260
- assert len(kwargs) == 0, f"unused kwargs in replace {list(kwargs)}"
261
- if (self.op, self.dtype, self.src, self.arg) == new_args: return self
262
- return UOp(*new_args)
263
- @functools.cached_property
264
- def key(self) -> bytes:
265
- return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
266
- def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
267
- def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
268
-
269
- @property
270
- def toposort(self) -> dict[UOp, None]:
271
- def _toposort(u:UOp, cache:set[UOp]):
272
- if u in cache: return {}
273
- nodes: dict[UOp, None] = {}
274
- # NOTE: this is a lot faster than the comprehension in parents
275
- for parent in u.src: nodes.update(_toposort(parent, cache))
276
- nodes[u] = None
277
- cache.add(u)
278
- return nodes
279
- return _toposort(self, cache=set())
280
-
281
- @functools.cached_property
282
- def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
283
-
284
- # *** uop shape stuff ***
285
-
286
- @functools.cached_property
287
- def st(self) -> ShapeTracker|None:
288
- from tinygrad.shape.shapetracker import ShapeTracker
289
- if self.op is Ops.MULTI:
290
- return ShapeTracker.from_shape(
291
- tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
292
- if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
293
- if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
294
- # these ops define a ShapeTracker from the arg
295
- if self.op is Ops.VIEW: return self.arg
296
- if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
297
- # buffer ops return the ShapeTracker from sources
298
- if self.op in GroupOp.Buffer: return vsrc[0] if len(vsrc:=[x.st for x in self.src if x.op is Ops.VIEW]) != 0 else None
299
- if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
300
- assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
301
- if self.op in {Ops.BITCAST, Ops.BUFFER_VIEW}:
302
- shape = src_sts[0].shape
303
- if self.dtype.itemsize != (input_sz:=self.src[0].dtype.itemsize): shape = shape[:-1]+((shape[-1]*input_sz) // self.dtype.itemsize,)
304
- # only reduce ops are allowed to change shape, everything else derives shape from sources
305
- elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg)
306
- else: shape = src_sts[0].shape
307
- return ShapeTracker.from_shape(shape)
308
-
309
- @functools.cached_property
310
- def full_shape(self) -> tuple[sint, ...]:
311
- if self.op is Ops.VIEW: return self.shape
312
- # TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
313
- return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
314
- # TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
315
- and not (x.op is Ops.CONST and x.st is None)]))
316
- @property
317
- def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
318
- @property
319
- def size(self) -> int: return self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
320
-
321
- # *** uop evaluation ***
322
-
323
- def simplify(self):
324
- # late import!
325
- from tinygrad.codegen.symbolic import symbolic
326
- with Context(TRACK_MATCH_STATS=0):
327
- return graph_rewrite(self, symbolic)
328
- def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
329
- def _eval(self, dtype, expected_type:Type[T]) -> T:
330
- assert self.dtype in dtype, f"eval with wrong dtype {self}"
331
- vmin, vmax = (simple_self:=self.simplify())._min_max
332
- if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
333
- assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}"
334
- return vmin
335
- def __bool__(self): return self._eval((dtypes.bool,), bool)
336
- def __int__(self): return self._eval(dtypes.ints, int)
337
- def __float__(self): return self._eval(dtypes.floats, float)
338
- def substitute(self, dvars:dict[UOp, UOp]):
339
- with Context(TRACK_MATCH_STATS=0):
340
- return graph_rewrite(self, _substitute, dvars, bottom_up=True)
341
-
342
- # *** uop syntactic sugar ***
343
-
344
- @property
345
- def st_arg(self) -> ShapeTracker:
346
- assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
347
- return unwrap(self.st)
348
- @property
349
- def axis_arg(self) -> tuple[int, ...]:
350
- assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
351
- ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
352
- assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
353
- return ret
354
- def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
355
- def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
356
- def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
357
- def const_like(self, b:ConstLike):
358
- # constants can optionally have a DEVICE source
359
- if self._device is None: return UOp.const(self.dtype, b)
360
- if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None)
361
- return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
362
- def broadcast(self, count:int):
363
- assert self.dtype.count == 1
364
- if count == 1: return self
365
- return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
366
- def cast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.CAST, dtype, (self,))
367
- def bitcast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
368
- def gep(self, i:Union[tuple[int, ...], int]):
369
- if isinstance(i, int):
370
- # NOTE: these are just shortcuts to not have to create and fold later
371
- if self.op is Ops.VECTORIZE: return self.src[i]
372
- if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
373
- if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
374
- i = (i,)
375
- if (self.dtype.vcount == len(i) and i == tuple(range(len(i)))) or self.dtype == dtypes.void: return self
376
- return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
377
- def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
378
- def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
379
- def alu(self, arg, *src:UOp):
380
- out_dtype = (self, *src)[-1].dtype
381
- if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
382
- return UOp(arg, out_dtype, (self,)+src)
383
- @staticmethod
384
- def const(dtype:DType, b:ConstLike):
385
- if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
386
- if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
387
- return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
388
- def valid(self, st:ShapeTracker):
389
- assert self.op in {Ops.CONST, Ops.DEFINE_VAR}, f"can only create VALID from a constant, got {self.op}"
390
- from tinygrad.shape.shapetracker import ShapeTracker
391
- # NOTE: only VALID has a masked ShapeTracker, the CONST operands are unmasked
392
- unmasked_st = ShapeTracker.from_shape(()).reshape((1,)*len(st.shape)).expand(st.shape).to_uop()
393
- return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
394
- @staticmethod
395
- def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
396
- def _reduce_op(self, op:Ops, axis:tuple[int, ...]):
397
- axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
398
- return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
399
- def r(self, op:Ops, axis:tuple[int, ...]) -> UOp:
400
- new_shape = unwrap(self.st).reduce(axis)
401
-
402
- # TODO: can we split symbolic shape if the reduce axis is not symbolic?
403
- # TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi
404
- if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \
405
- prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
406
- return self._reduce_op(op, axis)
407
-
408
- # if there are few globals, make some reduces into globals by splitting into two kernels
409
- # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
410
- # ~2**10 should be enough if GROUP is used
411
- # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
412
- # split is moved to the end to provide maximum locality for the second phase reduce.
413
- self_real_strides = unwrap(self.st).real_strides(ignore_valid=True)
414
- split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
415
- if self.shape[i] % x == 0 and self_real_strides[i] != 0]
416
- if not split_candidates: return self._reduce_op(op, axis)
417
- dim_to_split, divisor = split_candidates[0]
418
- splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
419
- splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
420
- if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
421
- return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
422
- def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
423
- def contiguous(self): return self.alu(Ops.CONTIGUOUS)
424
- def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
425
-
426
- # *** from MultiLazyBuffer ***
427
-
428
- def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None):
429
- parents = (self,)+more
430
- assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
431
- return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents)))
432
-
433
- @property
434
- def bounds(self):
435
- if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
436
- return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0)))
437
-
438
- @functools.cached_property
439
- def axis(self) -> Optional[int]:
440
- if self.op is Ops.MULTI: return self.arg[0]
441
- # NOTE: they all have to share an axis, we always choose [-1]
442
- if self.op in GroupOp.ALU: return axes[-1] if (axes := dedup([x.axis for x in self.src if x.axis is not None])) else None
443
- src_axis = self.src[0].axis
444
- if self.op is Ops.REDUCE_AXIS: return None if src_axis is not None and src_axis in self.arg[1] else src_axis
445
- if self.op is Ops.RESHAPE:
446
- if src_axis is None: return None
447
- arg_acc:list[sint] = list(itertools.accumulate(self.arg, operator.mul, initial=1))
448
- # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
449
- # TODO: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
450
- return len(arg_acc) - arg_acc[::-1].index(prod(self.src[0].shape[:src_axis])) - 1
451
- if self.op is Ops.PERMUTE: return self.arg.index(src_axis) if src_axis is not None else None
452
- return src_axis
453
-
454
- @property
455
- def real(self):
456
- assert self.op is Ops.MULTI
457
- return self.arg[1]
458
-
459
- @property
460
- def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r]
461
-
462
- def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp:
463
- if axis is None: lbs = [self] * len(devices)
464
- else:
465
- if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
466
- # NOTE: this works for both even shards and uneven shards
467
- sz = self.shape[axis] // len(devices)
468
- sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
469
- lbs = []
470
- for sz,off in zip(sizes, itertools.accumulate(sizes, initial=0)):
471
- lbs.append(self.shrink(tuple((0,s) if i != axis else (off,off+sz) for i,s in enumerate(self.shape))))
472
- sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
473
- return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], axis=axis)
474
-
475
- # *** from LazyBuffer ***
476
-
477
- @staticmethod
478
- def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None) -> UOp:
479
- from tinygrad.shape.shapetracker import ShapeTracker
480
- # Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
481
- if op is Ops.CONST:
482
- assert isinstance(arg, get_args(ConstType)), f"trying to create CONST with {arg=}"
483
- return UOp.const(dtype, unwrap(arg)).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),),
484
- ShapeTracker.from_shape(())),)).reshape((1,)*len(shape)).expand(shape)
485
- # Tensor variable binding is BIND(VAR(VIEW(DEVICE)), CONST(VIEW(DEVICE)))
486
- if op is Ops.BIND:
487
- var, val = arg.unbind()
488
- return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
489
- # otherwise it's just a RESHAPE(BUFFER)
490
- if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
491
- return UOp.new_buffer(device, size, dtype).reshape(shape)
492
- def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
493
- # if it's a shrink, do the shrink before the copy with CONTIGUOUS
494
- if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
495
- # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
496
- ret = UOp(Ops.COPY, self.base.dtype, (UOp(Ops.DEVICE, arg=device), self.base), clone)
497
- op_arg = []
498
- mop = self
499
- while mop is not self.base:
500
- op_arg.append((mop.op, mop.arg))
501
- mop = mop.src[0]
502
- for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
503
- return ret
504
- def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
505
- @property
506
- def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
507
-
508
- # *** uop movement ops ***
509
-
510
- @property
511
- def base(self) -> UOp:
512
- if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
513
- return self
514
- def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
515
-
516
- def _mop(self, op:Ops, arg):
517
- ret = UOp(op, self.dtype, (self,), arg)
518
- if self.st == ret.st: return self # ignore NOOPs, also check ret.st
519
- return ret
520
-
521
- def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
522
- def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
523
- def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
524
- def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
525
- def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
526
- def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)
527
-
528
- # *** uop UNIQUE ***
529
-
530
- # TODO: use this in Buffer
531
- unique_num = itertools.count(0)
532
- @staticmethod
533
- def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))
534
-
535
- # *** uop Buffer stuff ***
536
-
537
- @staticmethod
538
- def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device), UOp.unique()), size)
539
- @property
540
- def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
541
- @functools.cached_property
542
- def _device(self) -> Optional[str|tuple[str, ...]]:
543
- if self.op is Ops.DEVICE: return self.arg
544
- if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
545
- return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
546
- @property
547
- def buf_uop(self) -> UOp:
548
- if self.base.op is Ops.BUFFER: return self.base
549
- assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN}, f"buf_uop called on {self.op}"
550
- return self.src[0].buf_uop
551
- @property
552
- def buffer(self) -> Buffer:
553
- if self is not self.base:
554
- assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
555
- return self.src[0].buffer
556
- assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
557
- if (cret:=buffers.get(self)) is not None: return cret
558
- from tinygrad.device import Buffer
559
- assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
560
- buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
561
- return ret
562
- @property
563
- def realized(self) -> Optional[Buffer]: return self.buffer if self.op is Ops.BUFFER else None
564
- @property
565
- def is_realized(self) -> bool:
566
- return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None
567
-
568
- # *** uop Variable stuff ***
569
-
570
- @staticmethod
571
- def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
572
- assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
573
- return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
574
- @property
575
- def expr(self):
576
- assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
577
- return self.arg[0]
578
- def bind(self, val:int):
579
- assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
580
- assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
581
- return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
582
- def unbind(self) -> tuple[Variable, int]:
583
- assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
584
- return self.src[0], self.src[1].arg
585
- @property
586
- def val(self) -> int: return self.unbind()[1]
587
- def vars(self) -> set[UOp]:
588
- bound_vars = set([x for x in self.toposort if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
589
- bound_var_base = set(x.src[0] for x in bound_vars)
590
- all_vars = set([x for x in self.toposort if x.op is Ops.DEFINE_VAR])
591
- return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
592
- def variables(self) -> list[Variable]:
593
- st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer]
594
- return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
595
-
596
- # *** uop symbolic stuff ***
597
-
598
- def is_increasing(self:UOp) -> bool:
599
- # is f a monotonically increasing function regards its input
600
- if self.op in GroupOp.Irreducible: return True
601
- if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
602
- if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
603
- return False # False if not sure
604
- def const_factor(self) -> int:
605
- """largest known int that divides self"""
606
- if self.op is Ops.CONST: return self.arg
607
- if self.op is Ops.VCONST: return math.gcd(*self.arg)
608
- if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
609
- if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
610
- return 1
611
- def divides(self, v:int) -> UOp|None:
612
- if v==1: return self
613
- if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
614
- if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
615
- if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
616
- if self.op is Ops.MUL:
617
- if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
618
- if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
619
- return None # generic None if we aren't sure
620
- @property
621
- def vmin(self) -> ConstType: return self._min_max[0]
622
- @property
623
- def vmax(self) -> ConstType: return self._min_max[1]
624
- @functools.cached_property
625
- def _min_max(self) -> tuple[ConstType, ConstType]:
626
- if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
627
- (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
628
- if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
629
- if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
630
- # SHL/SHR on consts only
631
- if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
632
- if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
633
- if self.op is Ops.MOD and s1_vmin > 0:
634
- return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
635
- if self.op is Ops.IDIV:
636
- if s1_vmin == s1_vmax: # min/max are equal in a CONST
637
- if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
638
- if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
639
- # don't know exact bounds, but know the sign
640
- if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
641
- if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
642
- if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
643
- if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
644
- if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
645
- if self.dtype == dtypes.bool:
646
- if self.op is Ops.OR: return s0_vmin or s1_vmin, s0_vmax or s1_vmax
647
- if self.op is Ops.AND: return s0_vmin and s1_vmin, s0_vmax and s1_vmax
648
- # float has NAN issue and we use explicit NAN in transcendental
649
- if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
650
- # NOTE: returned UOp is assumed to be CONST
651
- if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
652
- if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
653
- if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
654
- if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
655
- # TODO: Ops.SPECIAL is Ops.DEFINE_VAR
656
- if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
657
- if self.op is Ops.CONST: return self.arg, self.arg
658
- if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
659
- return dtypes.min(self.dtype), dtypes.max(self.dtype)
660
-
661
- @functools.cached_property
662
- def _sym_fxn(self):
663
- sself = self.simplify()
664
- varnames = tuple(x.arg[0] for x in sself.toposort if x.op is Ops.DEFINE_VAR)
665
- # TODO: sanitize varnames, or don't use naked eval while staying fast
666
- return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used
667
-
668
- def sym_infer(self, var_vals:dict[UOp, int]):
669
- fxn, varnames = self._sym_fxn
670
- return fxn(**{k.arg[0]:v for k,v in var_vals.items() if k.arg[0] in varnames})
671
-
672
- def render(self, simplify=True) -> str:
673
- ret = graph_rewrite(self.simplify() if simplify else self, renderer)
674
- return ret.arg if ret.op is Ops.NOOP else str(ret)
675
-
676
- @dataclass(frozen=True)
677
- class KernelInfo:
678
- name: str = "test" # name of the kernel
679
- local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
680
- upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
681
- dont_use_locals: bool = False # don't use local indexing
682
-
683
- # ******** ops in python ********
684
-
685
- def safe_exp2(x):
686
- try: return 2 ** x
687
- except OverflowError: return math.inf
688
-
689
- def safe_pow(x, y):
690
- try: return math.nan if isinstance(p:=pow(x, y), complex) else p
691
- except ZeroDivisionError: return math.inf
692
- except ValueError: return math.inf if x > 0 else -math.inf
693
-
694
- python_alu: dict[Ops, Callable] = {
695
- Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
696
- Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
697
- Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
698
- Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
699
- Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
700
- Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
701
- Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}
702
-
703
- def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
704
- if dtype.count > 1:
705
- return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
706
- alu = python_alu[op](*operands)
707
- return truncate.get(dtype, lambda x: x)(alu) if truncate_output else alu
708
-
709
- # ***** uop helpers *****
710
-
711
- def print_uops(uops:list[UOp]):
712
- for i,u in enumerate(uops):
713
- formatted_parents = [(uops.index(x) if x.op is not Ops.CONST else f"{x.arg}") if x in uops else "--" for x in u.src]
714
- print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):30s} " f"{str(formatted_parents):32s} {u.arg}")
715
-
716
- # ***** pattern matcher *****
717
-
718
- def get_location() -> tuple[str, int]:
719
- frm = sys._getframe(1)
720
- # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
721
- while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
722
- "symbolic.py", "expander.py", "lowerer.py", "cstyle.py",
723
- "linearize.py"}:
724
- frm = frm.f_back
725
- return frm.f_code.co_filename, frm.f_lineno
726
- @functools.lru_cache(None)
727
- def lines(fn) -> list[str]:
728
- with open(fn) as f: return f.readlines()
729
-
730
- class UPat(MathTrait):
731
- __slots__ = ("op", "dtype", "arg", "name", "src")
732
- def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None,
733
- src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None,
734
- name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None):
735
- assert op is None or isinstance(op, Ops) or isinstance(op, tuple) or isinstance(op, set), "op must be Ops or tuple of Ops"
736
- self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
737
- self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
738
- self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
739
- self.src: Any = None
740
- assert self.name != "ctx", "UPat can't be named ctx"
741
-
742
- # try all permutations if it's a list
743
- if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
744
- # only one if it's a tuple
745
- elif isinstance(src, tuple): self.src = [src]
746
- # repeat if it's a UPat
747
- elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
748
-
749
- self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
750
- self.location = location or get_location()
751
-
752
- if custom_early_reject is not None: self.early_reject = custom_early_reject
753
- else:
754
- upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
755
- self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1}
756
-
757
- def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, self.allowed_len == -1, self.custom_early_reject)
758
-
759
- @staticmethod
760
- def any(*src): return UPatAny(src=src)
761
-
762
- @staticmethod
763
- @functools.lru_cache(None)
764
- def var(name:Optional[str]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
765
- @staticmethod
766
- @functools.lru_cache(None)
767
- def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
768
- @staticmethod
769
- def const(dtype:Optional[Union[DType, tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
770
-
771
- # copied from UOp
772
- def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
773
- def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
774
- def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
775
- def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
776
- def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
777
- def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
778
- def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
779
- def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))
780
-
781
- def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
782
- def alu(self, op:Ops, *src:UPat):
783
- asrc = (self,)+src
784
- return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
785
-
786
- def printable(self:UPat) -> str:
787
- try: return lines(self.location[0])[self.location[1]-1].strip()
788
- except FileNotFoundError: return "<missing>"
789
-
790
- def __repr__(self):
791
- def rep(x):
792
- form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
793
- return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
794
- set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
795
- return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
796
-
797
- def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
798
- if (self.op is not None and uop.op not in self.op) or \
799
- (self.name is not None and store.setdefault(self.name, uop) is not uop) or \
800
- (self.dtype is not None and uop.dtype not in self.dtype and uop.dtype.scalar() not in self.dtype) or \
801
- (self.arg is not None and self.arg != uop.arg) or \
802
- (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
803
- if self.src is None: return [store]
804
- res: list[dict[str, UOp]] = []
805
- for vp in self.src:
806
- stores, new_stores = [store.copy()], []
807
- for uu, vv in zip(uop.src, vp):
808
- for s in stores: new_stores.extend(vv.match(uu, s))
809
- stores, new_stores = new_stores, []
810
- res.extend(stores)
811
- return res
812
-
813
- class UPatAny(UPat):
814
- def match(self:UPat, uop:UOp, store:dict[str, UOp]) -> list[dict[str, UOp]]:
815
- matches = [x.match(uop, store.copy()) for x in self.src[0]]
816
- return flatten([x for x in matches if x is not None])
817
-
818
- def deconstruct_function(fxn:Callable) -> tuple:
819
- new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names}
820
- for co in fxn.__code__.co_consts:
821
- if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names})
822
- # NOTE: optional round trip through pickle!
823
- assert fxn.__closure__ is None, "closures are not supported in pattern matchers"
824
- ret = fxn.__code__, new_globals, fxn.__name__, fxn.__defaults__
825
- return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret
826
-
827
- class PatternMatcher:
828
- def __init__(self, patterns:list[tuple[UPat, Callable]]):
829
- self.patterns = patterns
830
- # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
831
- self.pdict: dict[Ops, list[tuple[UPat, Callable, set, bool]]] = {}
832
- # uop is required, arg is optional
833
- for p,fxn in self.patterns:
834
- assert p.op is not None
835
- tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn)
836
- real_fxn = types.FunctionType(*tuple_fxn)
837
- for uop in p.op: self.pdict.setdefault(uop, []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters))
838
-
839
- def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)
840
-
841
- @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
842
- def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
843
-
844
- def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
845
- ler = {u.op for u in uop.src}
846
- for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
847
- if not early_reject.issubset(ler): continue
848
- for match in p.match(uop, {}):
849
- if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None: return ret
850
- return None
851
-
852
- # *** tracking pattern matcher ***
853
-
854
- TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
855
- match_stats:dict[UPat, list[Union[int, float]]] = dict()
856
- @dataclass(frozen=True)
857
- class TrackedGraphRewrite:
858
- loc: tuple[str, int] # location that called graph_rewrite
859
- sink: UOp # the sink input to graph_rewrite
860
- bottom_up: bool
861
- matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
862
- name: Optional[str] = None
863
- tracked_keys:list[Any] = []
864
- tracked_ctxs:list[list[TrackedGraphRewrite]] = []
865
- _name_cnt:dict[str, int] = {}
866
- def track_rewrites(named=False):
867
- def _decorator(func):
868
- def __wrapper(self, *args, **kwargs):
869
- if TRACK_MATCH_STATS >= 2:
870
- if named: _name_cnt[func.__name__] = _name_cnt.get(func.__name__, 0)+1
871
- tracked_keys.append(f"{func.__name__}_{_name_cnt[func.__name__]}" if named else self)
872
- tracked_ctxs.append([])
873
- return func(self, *args, **kwargs)
874
- return __wrapper
875
- return _decorator
876
-
877
- class TrackedPatternMatcher(PatternMatcher):
878
- def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
879
- ret = None
880
- ler = {u.op for u in uop.src}
881
- for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
882
- if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
883
- st = time.perf_counter()
884
- if not early_reject.issubset(ler):
885
- match_stats[p][2] += time.perf_counter()-st
886
- continue
887
- match_stats[p][1] += 1
888
- for match in p.match(uop, {}):
889
- if (ret:=(fxn(ctx=ctx, **match) if has_ctx else fxn(**match))) is not None:
890
- match_stats[p][0] += 1
891
- match_stats[p][3] += (et:=time.perf_counter()-st)
892
- if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
893
- if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p))
894
- return ret # NOTE: if it returns None, we keep trying to match
895
- match_stats[p][2] += time.perf_counter()-st
896
- return None
897
-
898
- if TRACK_MATCH_STATS:
899
- PatternMatcher = TrackedPatternMatcher # type: ignore
900
- import atexit
901
- @atexit.register
902
- def print_match_stats():
903
- if TRACK_MATCH_STATS >= 2:
904
- with open(fn:=temp("rewrites.pkl", append_user=True), "wb") as f:
905
- print(f"rewrote {len(tracked_ctxs)} graphs and matched {sum(len(r.matches) for x in tracked_ctxs for r in x)} times, saved to {fn}")
906
- with Context(PICKLE_BUFFERS=0): pickle.dump((tracked_keys, tracked_ctxs), f)
907
- if getenv("VIZ"): launch_viz("VIZ", temp("rewrites.pkl", append_user=True))
908
- if getenv("PRINT_MATCH_STATS", 1):
909
- ret = [0,0,0.0,0.0]
910
- for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]+x[1][3]):
911
- loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
912
- if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {(v[2]+v[3])*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
913
- ret = [x+y for x,y in zip(ret, v)]
914
- print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {(ret[2]+ret[3])*1000.:9.2f} ms -- TOTAL")
915
-
916
- def launch_viz(env_str:str, data:str):
917
- os.environ[env_str] = "0"
918
- os.environ[f"{env_str}_DATA"] = data
919
- if not int(os.getenv("VIZ", "0")) and not int(os.getenv("PROFILE", "0")):
920
- args = ['--kernels', getenv("VIZ_DATA", "")] if getenv("VIZ_DATA", "") else []
921
- args += ['--profile', getenv("PROFILE_DATA", "")] if getenv("PROFILE_DATA", "") else []
922
- os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), ".", "viz", "serve.py")] + args)
923
-
924
- # *** simple graph rewrite engine ***
925
-
926
- class RewriteContext:
927
- def __init__(self, pm, ctx=None):
928
- self.pm: PatternMatcher = pm
929
- self.ctx = ctx
930
- self.replace: dict[UOp, UOp] = {}
931
- def top_down_rewrite(self, n:UOp) -> UOp:
932
- if (rn := self.replace.get(n)) is not None: return rn
933
- new_src = tuple([self.top_down_rewrite(x) for x in n.src])
934
- new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
935
- self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
936
- return ret
937
- def bottom_up_rewrite(self, n:UOp) -> UOp:
938
- if (rn := self.replace.get(n)) is not None: return rn
939
- new_n: UOp|None = n
940
- while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
941
- new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
942
- self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
943
- return ret
944
-
945
- def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
946
- if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
947
- tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
948
- return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
949
-
950
- def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
951
- if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
952
- tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
953
- rewrite_ctx = RewriteContext(pm, ctx)
954
- return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
955
-
956
- def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
957
-
958
- _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
959
-
960
- # for debug
961
- syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>",
962
- Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
963
- renderer = PatternMatcher([
964
- (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
965
- (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
966
- (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
967
- (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
968
- (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
969
- (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
970
- (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
971
- (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
972
- (UPat(GroupOp.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
973
- ])
974
-
975
- # *** what was symbolic.py ***
976
-
977
- sint = Union[int, UOp]
978
- Variable = UOp
979
-
980
- ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
981
-
982
- # *** UOp merge views and swizzling ***
983
-
984
- merge_views = PatternMatcher([
985
- # VIEW(VIEW) merges to a single VIEW
986
- (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
987
- # remove VIEW if it's contiguous and same as the base shape
988
- (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda vm,x: x if vm.st.contiguous and x.shape == vm.shape else None),
989
- # merge unmasked const views
990
- (UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
991
- lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
992
- ])
993
-
994
- # push VIEW to parents
995
- view_left = merge_views+PatternMatcher([
996
- # VIEW(CONST) becomes VALID
997
- (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
998
- # VIEW before elementwise/buffer ops
999
- (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
1000
- lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
1001
- (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
1002
- lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
1003
- ])