tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/codegen/uops.py CHANGED
@@ -1,10 +1,10 @@
1
1
  from __future__ import annotations
2
- from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable
3
- import functools, itertools, heapq
2
+ from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
3
+ import functools, itertools, heapq, math
4
4
  from collections import defaultdict
5
5
  from enum import Enum, auto
6
- from dataclasses import dataclass
7
- from tinygrad.dtype import dtypes, DType
6
+ from dataclasses import dataclass, field
7
+ from tinygrad.dtype import ConstType, dtypes, DType
8
8
  from tinygrad.shape.symbolic import sint, Variable
9
9
  from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
10
10
  from tinygrad.helpers import prod, DEBUG, getenv
@@ -12,7 +12,7 @@ from tinygrad.helpers import prod, DEBUG, getenv
12
12
  # the order of these UOps controls the order of the toposort
13
13
  class UOps(Enum):
14
14
  # ops that aren't rendered
15
- SINK = auto()
15
+ SINK = auto(); VAR = auto() # noqa: E702
16
16
  DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
17
17
  CONST = auto(); SPECIAL = auto() # noqa: E702
18
18
  NOOP = auto(); UNMUL = auto(); GEP = auto() # noqa: E702
@@ -26,89 +26,119 @@ class UOps(Enum):
26
26
  # these two are not graph nodes
27
27
  ENDRANGE = auto(); ENDIF = auto() # noqa: E702
28
28
 
29
- @dataclass(eq=False)
29
+ def ufix(dtype: Optional[DType], x): return UOp.const(dtype, x) if not isinstance(x, UOp) else x
30
+ @dataclass(frozen=True, eq=False)
30
31
  class UOp:
31
- uop: UOps
32
+ op: UOps
32
33
  dtype: Optional[DType] = None
33
- vin: Tuple[UOp, ...] = tuple()
34
+ src: Tuple[UOp, ...] = tuple()
34
35
  arg: Any = None
35
- def tuple(self): return (self.uop, self.dtype, self.vin, self.arg)
36
+ def commutative(self) -> bool:
37
+ return self.op is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR}
36
38
  @functools.cached_property
37
39
  def cmp_tuple(self):
38
40
  # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
39
- return (self.uop.value, (self.arg if self.uop is not UOps.DEFINE_VAR else self.arg.expr) if self.uop is not UOps.ALU else \
40
- (type(self.uop), self.uop.value), self.dtype, self.vin)
41
+ return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
42
+ (type(self.op), self.op.value), self.dtype, self.src)
41
43
  def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
42
44
  def __repr__(self):
43
- return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
44
- def cast(self, dtype): return UOp(UOps.CAST, dtype, (self,))
45
+ return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
46
+ def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
47
+ def name(self, name:Optional[str]): return UOp(UOps.VAR, src=(self,), arg=name)
45
48
  def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
46
- def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, x)
47
- def __sub__(self, x): return UOp.alu(BinaryOps.SUB, self, x)
48
- def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, x)
49
+ def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x))
50
+ def __radd__(self, x): return UOp.alu(BinaryOps.ADD, ufix(self.dtype, x), self)
51
+ def __sub__(self, x): return UOp.alu(BinaryOps.ADD, self, -ufix(self.dtype, x))
52
+ def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, ufix(self.dtype, x))
53
+ def __rmul__(self, x): return UOp.alu(BinaryOps.MUL, ufix(self.dtype, x), self)
54
+ def __floordiv__(self, x): return UOp.alu(BinaryOps.IDIV, self, ufix(self.dtype, x))
55
+ def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
56
+ def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
57
+ def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
58
+ def ge(self, x): return -self.lt(x)
59
+ def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
60
+ def min(self, x): return -UOp.alu(BinaryOps.MAX, -self, -x)
49
61
  @staticmethod
50
- def max(x, y): return UOp.alu(BinaryOps.MAX, x, y)
62
+ @functools.lru_cache(maxsize=None)
63
+ def const(dtype:Optional[DType], b:ConstType|Variable):
64
+ if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
65
+ return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
51
66
  @staticmethod
52
- def min(x, y): return -UOp.alu(BinaryOps.MAX, -x, -y)
67
+ def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
53
68
  @staticmethod
54
- def const(dtype, val): return UOp(UOps.CONST, dtype, arg=dtypes.as_const(val, dtype))
69
+ def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
55
70
  @staticmethod
56
- def alu(arg, *vin:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else vin[-1].dtype, vin, arg)
71
+ def store(*src:UOp, **kwargs): return UOp(UOps.STORE, None, tuple(src)+tuple(kwargs.values()))
72
+ @staticmethod
73
+ def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
74
+ @staticmethod
75
+ def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name)
57
76
  @functools.cached_property
58
- def parents(self) -> Set[UOp]: return set.union(set(self.vin), *[x.parents for x in self.vin])
77
+ def parents(self) -> Set[UOp]: return set.union(set(self.src), *[x.parents for x in self.src])
78
+ @property # parents with self
79
+ def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
80
+ def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
59
81
 
60
82
  def uop_alu_resolve(u:UOp) -> sint:
61
- if u.uop is UOps.CONST: return u.arg
62
- if u.uop is UOps.DEFINE_VAR: return u.arg
63
- if u.uop is UOps.SPECIAL: return u.arg[2]-1
64
- if u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
65
- if u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
66
- raise RuntimeError(f"ALU resolve fail @ {u.uop}")
83
+ if u.op is UOps.CONST: return u.arg
84
+ if u.op is UOps.DEFINE_VAR: return u.arg
85
+ if u.op is UOps.SPECIAL: return u.arg[2]-1
86
+ if u.op is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.src[0]) * uop_alu_resolve(u.src[1])
87
+ if u.op is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.src[0]) * (2**cast(int, uop_alu_resolve(u.src[1])))
88
+ if u.op is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.src[0]) + uop_alu_resolve(u.src[1])
89
+ raise RuntimeError(f"ALU resolve fail @ {u.op}")
67
90
 
68
91
  # *** simplification logic ***
69
92
 
70
- def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
71
- for k,v in pattern.items():
72
- if k == "__name__":
73
- if v in store and store[v] != uop: return False
74
- store[v] = uop
75
- elif k == "arg":
76
- if uop.arg != v: return False
77
- elif k == "dtype":
78
- if isinstance(v, set):
79
- if uop.dtype not in v: return False
80
- elif uop.dtype != v: return False
81
- elif k == "uop":
82
- if isinstance(v, set):
83
- if uop.uop not in v: return False
84
- elif uop.uop != v: return False
85
- elif k == "vin":
86
- # only one if it's a tuple
87
- # try all permutations if it's a list
88
- # repeat if it's a dict
89
- for vp in itertools.permutations(v) if isinstance(v, list) else ([v] if isinstance(v, tuple) else [(v,)*len(uop.vin)]):
90
- if len(uop.vin) != len(vp) and (len(uop.vin) not in pattern.get('__allow_len__', [])): return False
91
- new_store = store.copy()
92
- if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
93
- for k,v in new_store.items(): store[k] = v
94
- return True
95
- return False
96
- return True
93
+ @dataclass(frozen=True)
94
+ class UPat:
95
+ op: Optional[Union[UOps, Set[UOps]]] = None
96
+ arg: Any = None
97
+ src: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
98
+ name: Optional[str] = None
99
+ dtype: Optional[Union[DType, Set[DType]]] = None
100
+ allow_len: Set[int] = field(default_factory=set)
101
+
102
+ @staticmethod
103
+ def compile(u: UOp, name:Optional[str]=None) -> UPat:
104
+ if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
105
+ return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
106
+
107
+ T = TypeVar("T")
108
+ def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1
109
+
110
+ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
111
+ if pat.name is not None and store.setdefault(pat.name, uop) is not uop: return False
112
+ if pat.arg is not None and __unmatch(pat.arg, uop.arg): return False
113
+ if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return False
114
+ if pat.op is not None and __unmatch(pat.op, uop.op): return False
115
+ if pat.src is None: return True
116
+ # only one if it's a tuple
117
+ # try all permutations if it's a list
118
+ # repeat if it's a UPat
119
+ for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
120
+ if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
121
+ new_store = store.copy()
122
+ if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
123
+ store.update(new_store)
124
+ return True
125
+ return False
97
126
 
98
127
  class PatternMatcher:
99
- def __init__(self, patterns:List[Tuple[Dict[str, Any], Callable]]):
128
+ def __init__(self, patterns:List[Tuple[Union[UPat, UOp], Callable]]):
100
129
  self.patterns = patterns
101
- self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[Dict[str, Any], Callable]]] = defaultdict(list)
130
+ self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
102
131
  # uop is required, arg is optional
103
132
  for p,fxn in self.patterns:
104
- uops = p["uop"]
105
- if isinstance(uops, set):
106
- for uop in uops: self.pdict[(uop, p.get("arg", None))].append((p, fxn))
133
+ if isinstance(p, UOp): p = UPat.compile(p)
134
+ assert p.op is not None
135
+ if isinstance(p.op, set):
136
+ for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
107
137
  else:
108
- self.pdict[(uops, p.get("arg", None))].append((p, fxn))
138
+ self.pdict[(p.op, p.arg)].append((p, fxn))
109
139
 
110
140
  def rewrite(self, uop:UOp) -> Optional[UOp]:
111
- for p,fxn in itertools.chain(self.pdict[(uop.uop, uop.arg)], self.pdict[(uop.uop, None)]):
141
+ for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
112
142
  store: Dict[str, UOp] = {}
113
143
  if _match(uop, p, store): return fxn(**store)
114
144
  return None
@@ -116,130 +146,169 @@ class PatternMatcher:
116
146
  def sum_collapse(phi_input, loop, val1, val2):
117
147
  for v1,v2 in [(val1, val2), (val2, val1)]:
118
148
  if loop not in v1.parents:
119
- loop_range = loop.vin[1]-loop.vin[0]
149
+ loop_range = loop.src[1]-loop.src[0]
120
150
  ret = v1*loop_range.cast(v1.dtype)
121
151
  return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+ret
122
152
  return None
123
153
 
124
- def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
154
+ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng):
155
+ if not rng.arg[2]: return None # must be a reduce
125
156
  if mval.arg >= 0 or loop_start.arg != 0:
126
157
  # TODO: support and test this with other mvals and loop_starts
127
158
  if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
128
159
  return None
129
- comprange = UOp.min(loop_end, UOp.max(UOp.alu(BinaryOps.DIV, idx-compval-mval, mval) + (loop_end-loop_start), loop_start))
160
+ comprange = UOp.min(loop_end, UOp.max(UOp.alu(BinaryOps.IDIV, idx-compval-mval, mval) + (loop_end-loop_start), loop_start))
130
161
  return UOp(UOps.UNMUL, multconst.dtype, (comprange.cast(multconst.dtype) * multconst, loop_end-loop_start))
131
162
 
132
163
  # this is symbolic 2.0
133
164
  constant_folder = PatternMatcher([
134
165
  # arange loop folding (early)
135
- ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": (
136
- {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin":
137
- [{"__name__": "idx"}, {"uop": UOps.ALU, "arg": BinaryOps.MUL,
138
- "vin": [{"__name__": "mval", "uop": UOps.CONST}, {"uop": UOps.RANGE, "vin": ({"__name__": "loop_start"}, {"__name__": "loop_end"})}]}]},
139
- {"__name__": "compval", "uop": UOps.CONST})}, {"__name__": "multconst", "uop": UOps.CONST}, {"uop": UOps.CONST, "arg": 0})}, loop_collapse),
166
+ (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
167
+ UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, src=[UPat(UOps.CONST, name="mval"),
168
+ UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
169
+ UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
170
+ (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
171
+ UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, UnaryOps.NEG, src=[
172
+ UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
173
+ UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))),
174
+ lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
140
175
  # sum collapse to mul (with possible GEP)
141
- ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.DEFINE_ACC, "vin": ({"uop": UOps.RANGE, "__name__": "loop"},)},
142
- {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
143
- ({"uop": UOps.PHI, "vin": ({"__name__": "phi_input", "uop": UOps.GEP,
144
- "vin": ({"uop": UOps.DEFINE_ACC, "vin":({"uop": UOps.RANGE, "__name__": "loop"},)},)},
145
- {"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "val1"}, {"__name__": "val2"})})}, sum_collapse),
176
+ (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
177
+ UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
178
+ (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
179
+ UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
146
180
  # deal with UNMUL
147
- ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"uop": UOps.CONST, "__name__": "c1"},
148
- {"uop": UOps.UNMUL, "vin": [{"uop": UOps.CONST, "__name__": "c2"}, {"__name__": "v"}]}]},
149
- lambda c1,c2,v: v if c1.arg == c2.arg else None),
150
- ({"uop": UOps.UNMUL, "vin": ({"uop": UOps.CONST, "__name__": "zero", "arg": 0}, {})}, lambda zero: zero),
151
- ({"__name__": "root", "uop": UOps.CAST, "vin": ({"uop": UOps.UNMUL, "__name__": "unmul"},)},
152
- lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
181
+ (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
182
+ lambda c1,c2,v: v if c1.arg == c2.arg else None),
183
+ (UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
184
+ (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
153
185
  # max on special can go away (TODO: special should be variable, same thing applies)
154
- ({"uop": UOps.ALU, "arg": BinaryOps.MAX, "vin": [{"__name__": "c", "uop": UOps.CONST}, {"__name__": "s", "uop": UOps.SPECIAL}]},
155
- lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
186
+ (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
156
187
  # const rules
157
- ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),
158
- ({"__name__": "root", "uop": UOps.CAST, "vin": {"__name__": "c", "uop": UOps.CONST}}, lambda root, c: UOp.const(root.dtype, c.arg)),
188
+ (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
189
+ (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
159
190
  # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
160
- ({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "__name__": "acc"}, {"__name__": "acc"})}, lambda acc: UOp.const(acc.dtype, acc.arg[0])),
161
- ({"uop": UOps.PHI, "vin": ({"uop": UOps.DEFINE_ACC, "vin": tuple()}, {"__name__": "x"})}, lambda x: x),
162
- ({"uop": UOps.PHI, "vin": ({"uop": UOps.CONST}, {"__name__": "x"})}, lambda x: x),
191
+ (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
192
+ (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
193
+ (UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
163
194
  # a DEFINE_ACC without inputs is a const + GEP on a const is the const
164
- ({"__name__": "root", "uop": UOps.DEFINE_ACC, "vin": tuple()}, lambda root: UOp.const(root.dtype, root.arg[0])),
165
- ({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "x", "uop": UOps.CONST},)}, lambda root,x: UOp.const(root.dtype, x.arg)),
195
+ (UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
196
+ (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
166
197
  # max -2147483648
167
- ({"uop": UOps.ALU, "arg": BinaryOps.MAX, "dtype": dtypes.int, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -2147483648}]}, lambda x: x),
168
- # -(-x) -> x
169
- ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)})}, lambda x: x),
170
- # x+-y -> x-y
171
- ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
172
- lambda x, my: x-my.vin[0]),
173
- # -1*x -> -x
174
- ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": -1}]}, lambda x: -x),
198
+ (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
175
199
  # bool < False is always false, True < bool is always false
176
- ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({}, {"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": False})}, lambda x: x),
177
- ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "uop": UOps.CONST, "dtype": dtypes.bool, "arg": True}, {})},
178
- lambda x: UOp.const(dtypes.bool, False)),
200
+ (UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
201
+ (UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
179
202
  # a conditional with the same results either way is a noop, also fold const conditionals
180
- ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val),
181
- ({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},
182
- lambda gate, c0, c1: c0 if gate.arg else c1),
203
+ (UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val),
204
+ (UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
183
205
  # ** constant folding **
184
- ({"__name__": "root", "uop": UOps.ALU, "vin": {"uop": UOps.CONST}},
185
- lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
206
+ (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
186
207
  # ** self folding **
187
- ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 0}]}, lambda x: x), # x+0 -> x or 0+x -> x
188
- ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 1}]}, lambda x: x), # x*1 -> x or 1*x -> x
189
- ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 0})}, lambda x: x), # x-0 -> x
190
- ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 1})}, lambda x: x), # x/1 -> x
191
- ({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": -1})}, lambda x: -x), # x/-1 -> -x
208
+ (-(-UOp.var('x')), lambda x: x), # -(-x) -> x
209
+ (UOp.var('x') + 0, lambda x: x), # x+0 -> x
210
+ (UOp.var('x') - 0, lambda x: x), # x-0 -> x
211
+ (UOp.var('x') * 1, lambda x: x), # x*1 -> x
212
+ (UOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
213
+ (UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
214
+ (UOp.var('x') // 1, lambda x: x), # x//1 -> x
215
+ (UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
216
+ (UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
217
+ (UOp.var('x') / UOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c)
218
+ (UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
192
219
  # ** zero folding **
193
- ({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
194
- ({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
220
+ #x*0 -> 0 or 0*x -> 0
221
+ #if x is nan or inf it should render the nan value.
222
+ # NOTE: this can be wrong for loaded NaN
223
+ (UOp.var('x') * 0, lambda x: UOp.const(x.dtype, float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
224
+ (UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
195
225
  # ** load/store folding **
196
- ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"},
197
- {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})}, lambda buf, idx: UOp(UOps.NOOP)),
226
+ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
198
227
  # ** two stage add/sub folding **
199
- ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.ADD,
200
- "vin": [{"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST}]}, {"__name__": "c2", "uop": UOps.CONST}]},
201
- lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
202
- ({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"uop": UOps.ALU, "arg": BinaryOps.SUB,
203
- "vin": ({"__name__": "x"}, {"__name__": "c1", "uop": UOps.CONST})}, {"__name__": "c2", "uop": UOps.CONST}]},
204
- lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c2.arg, c1.arg]))),
228
+ ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
229
+ ((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
230
+ # *** rules from symbolic ***
231
+ # two stage mul, (x*c1)*c2 = x*(c1*c2)
232
+ ((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
233
+ # x%1 -> 0
234
+ (UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
235
+ # (x*c0)+(x*c1) -> x*(c0+c1)
236
+ (UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
237
+ # (x*c0)+(y*c0) -> (x+y)*c0
238
+ #((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
239
+ # (x*c0)//c0 -> x
240
+ ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
241
+ # (x*x2)/x2 -> x
242
+ ((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
243
+ # (x//c0)//c1 -> x//(c0*c1)
244
+ ((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
245
+ # (x/x1)/x2 -> x/(x1*x2)
246
+ ((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
247
+ # c0 + x < c1 -> x < c1 - c0
248
+ ((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
249
+ lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
250
+ # (x+x*c0)-> x*(c0+1)
251
+ (UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
205
252
  # TODO: can do the invert of this (flip alt/load) when we fix double ops
206
- ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.ALU, "arg": TernaryOps.WHERE,
207
- "vin": ({"__name__": "gate"}, {"__name__": "alt"}, {"uop": UOps.LOAD, "vin": ({"__name__": "buf"}, {"__name__": "idx"})})})},
208
- lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
253
+ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
254
+ lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
209
255
  # store float4/float2 directly (remove CAST/GEP)
210
- ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
211
- tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(4))})},
212
- lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
213
- ({"uop": UOps.STORE, "vin": ({"__name__": "buf"}, {"__name__": "idx"}, {"uop": UOps.CAST, "vin":
214
- tuple({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i} for i in range(2))})},
215
- lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
256
+ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
257
+ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
216
258
  # CAST-PHI-GEP -> PHI-CAST
217
- ({"__name__": "root", "uop": UOps.CAST, "vin":
218
- tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(4))},
259
+ (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
219
260
  lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
220
- ({"__name__": "root", "uop": UOps.CAST, "vin":
221
- tuple({"uop": UOps.PHI, "vin": ({"uop": UOps.GEP, "vin": ({"__name__": "val"},), "arg": i}, {"__name__": f"v{i}"})} for i in range(2))},
261
+ (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
222
262
  lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
223
263
  # NEG/CMPLT -> CMPLT
224
- ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)},
225
- {"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
226
- lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
264
+ (UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
227
265
  # cast NOOP (NOTE: it's str to deal with PtrDType)
228
- ({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
266
+ (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
267
+ # fold gated LOAD/STORE
268
+ (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
269
+ (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var"), UOp.var("barrier")),
270
+ lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
271
+ (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var")), lambda var: var),
272
+ (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var"), UOp.var()), lambda var: var),
273
+ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(None, 1)), UOp.store),
274
+ (UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(None, 0)), lambda: UOp(UOps.NOOP)),
275
+ # remove NOOPs from SINK
276
+ (UPat(UOps.SINK, name="root"),
277
+ lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)
229
278
  ])
230
279
 
231
280
  # *** uop graph ***
232
281
 
282
+ def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
283
+ if u in children: return
284
+ children[u] = []
285
+ for x in u.src:
286
+ get_children_dfs(x, children, in_degree)
287
+ children[x].append(u)
288
+ in_degree[u] = len(u.src)
289
+
290
+ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
291
+ nodes: Dict[Tuple, UOp] = {}
292
+ replace: Dict[UOp, UOp] = {}
293
+ def __inner_rewrite(n:UOp) -> UOp:
294
+ if n in replace: return replace[n]
295
+ replace_source = (n.op, n.dtype, tuple(__inner_rewrite(y) for y in n.src), n.arg)
296
+ if found := nodes.get(replace_source): replace[n] = found
297
+ else: nodes[replace_source] = replace[n] = __inner_rewrite(new_x) if (new_x := pm.rewrite(x:=UOp(*replace_source))) else x
298
+ return replace[n]
299
+ return __inner_rewrite(sink)
300
+
233
301
  class UOpGraph:
234
- def __init__(self):
235
- self.nodes: Dict[Tuple, UOp] = {}
302
+ def __init__(self, sinks:List[UOp]):
303
+ self.sinks: List[UOp] = sinks
304
+ # used by linearizer
236
305
  self._uops: Optional[List[UOp]] = None
237
306
 
238
307
  def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
239
308
  def __getitem__(self, index) -> UOp: return self.uops[index]
240
309
 
241
- def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
242
- def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]
310
+ def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr)
311
+ def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]
243
312
 
244
313
  @property
245
314
  def uops(self):
@@ -252,164 +321,131 @@ class UOpGraph:
252
321
 
253
322
  def print(self):
254
323
  for i,u in enumerate(self):
255
- print(f"{i:4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")
256
-
257
- def graph_rewrite(self, sink, pm):
258
- # recursive rewrite
259
- changed = getenv("UOPS_REWRITE", 1)
260
- run_cnt = 0
261
- while changed:
262
- changed = 0
263
- @functools.lru_cache
264
- def rewrite(u:UOp) -> UOp:
265
- nonlocal changed
266
- recurse_cnt = 0
267
- up = u
268
- # locally recursively rewrite
269
- while (rewritten := pm.rewrite(up)):
270
- assert recurse_cnt < 100, f"recursive_rewrite looped {up} <--> {rewritten}"
271
- up = rewritten
272
- recurse_cnt += 1
273
- changed += recurse_cnt
274
- # NOTE: this changes UOp, so we have to delete caches
275
- up.vin = tuple(rewrite(x) for x in up.vin)
276
- if hasattr(up, "parents"): del up.parents
277
- if hasattr(up, "cmp_tuple"): del up.cmp_tuple
278
- # replace with cached nodes
279
- if found:=self.nodes.get(key:=up.tuple()): return found
280
- else: self.nodes[key] = up
281
- return up
282
- sink = rewrite(sink)
283
- run_cnt += 1
284
- assert run_cnt < 100, "exceeded 100 rewrite loops!"
285
- return sink
324
+ print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
286
325
 
287
326
  def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
288
327
  # NOTE: relinearizering should be okay
289
328
  #assert self._uops is None, "already linearized"
329
+ sink = UOp(UOps.SINK, None, tuple(self.sinks))
290
330
 
291
- # get sink
292
- _sinks: List[UOp] = []
293
- for u in self.nodes.values():
294
- if u.uop is UOps.STORE: _sinks.append(u)
295
- if u.uop is UOps.SINK: _sinks.extend(u.vin)
296
- sink = UOp(UOps.SINK, None, tuple(_sinks))
297
- del _sinks
298
-
299
- sink = self.graph_rewrite(sink, constant_folder)
300
- if extra_pm: sink = self.graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))
331
+ # dedup all nodes and do graph rewrite
332
+ sink = graph_rewrite(sink, constant_folder)
333
+ if extra_pm: sink = graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))
301
334
 
302
335
  # filter nodes that don't link to a sink
303
336
  # BFS toposort
304
- graph: DefaultDict[UOp, List[UOp]] = defaultdict(list)
305
- in_degree: DefaultDict[UOp, int] = defaultdict(int)
306
- loops = []
307
- ifs = []
308
- nodes: Dict[UOp, None] = {}
309
- def add_parents(u:UOp):
310
- if u in nodes: return
311
- nodes[u] = None
312
- for x in u.vin:
313
- add_parents(x)
314
- in_degree[u] += 1
315
- graph[x].append(u)
316
- if u.uop is UOps.RANGE: loops.append(u)
317
- if u.uop is UOps.IF: ifs.append(u)
318
- sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
319
- add_parents(sink)
337
+ children: Dict[UOp, List[UOp]] = {}
338
+ in_degree: Dict[UOp, int] = {}
339
+ get_children_dfs(sink, children, in_degree)
320
340
 
321
341
  @functools.lru_cache(None)
322
- def get_recursive_children(x:UOp, include_self=False) -> Set[UOp]:
323
- if x.uop is UOps.SINK: return set()
324
- return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, True) for u in graph[x]] if x.uop is not UOps.PHI else []))
325
- loops_children = {l:get_recursive_children(l) for l in loops[::-1]}
342
+ def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
343
+ if x.op is UOps.SINK: return set()
344
+ return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
345
+
346
+ # scope children impact the toposort and END* insertion
347
+ end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
348
+ loops, ifs = [x for x in in_degree if x.op is UOps.RANGE], [x for x in in_degree if x.op is UOps.IF]
349
+ scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]}
326
350
 
327
- queue: List = []
328
- def push(u):
351
+ queue:List[Tuple[int, UOp]] = []
352
+ def push(u:UOp):
329
353
  priority = 0
330
354
  # prefer uops that are loop children
331
- for l, ss in loops_children.items():
332
- if u in ss: priority -= l.arg[0]*1000 + l.arg[1]
355
+ for l, ss in scope_children.items():
356
+ if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
333
357
  heapq.heappush(queue, (priority, u))
334
358
 
335
- for u in nodes:
359
+ for u in children:
336
360
  if in_degree[u] == 0: push(u)
337
361
 
338
362
  if getenv("FUZZ_UOPS", 0):
339
363
  from test.external.fuzz_uops import fuzz_uops
340
- self.fuzz_paths = fuzz_uops(graph, in_degree.copy(), loops_children)
364
+ self.fuzz_paths = fuzz_uops(children, in_degree.copy(), scope_children)
341
365
 
342
366
  self._uops = []
343
367
  while queue:
344
368
  p,x = heapq.heappop(queue)
345
369
  if DEBUG >= 7: print(p,x)
346
- if x.uop is UOps.DEFINE_ACC and len(x.vin):
347
- idx = min([self._uops.index(l) for l in x.vin])
370
+ if x.op is UOps.DEFINE_ACC and len(x.src) > 1:
371
+ idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
348
372
  self._uops.insert(idx, x)
349
373
  else:
350
374
  self._uops.append(x)
351
- for u, ss in loops_children.items():
375
+ for u, ss in scope_children.items():
352
376
  if x in ss:
353
377
  ss.remove(x)
354
- if len(ss) == 0: self._uops.append(UOp(UOps.ENDRANGE, None, (u,)))
355
- for u in graph[x]:
378
+ if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.op][1], None, (u,)))
379
+ for u in children[x]:
356
380
  in_degree[u] -= 1
357
381
  if in_degree[u] == 0: push(u)
358
382
 
359
- assert self._uops[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
383
+ assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
360
384
  self._uops = self._uops[:-1]
361
385
 
362
- # TODO: ifs should be removed and just the store should be gated
363
- for u in ifs[::-1]: self._uops.append(UOp(UOps.ENDIF, None, (u,)))
364
-
365
386
  if type_verify: self.type_verify()
366
387
 
367
- def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
368
- if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
369
- self.nodes[key] = ret = UOp(*key)
370
- return ret
371
-
372
388
  # *** checker functions ***
373
389
 
374
- def flops_mem(self) -> Tuple[sint, sint]:
390
+ def flops_mem(self, ignore_indexing=False) -> Tuple[sint, sint]:
375
391
  flops: sint = 0
376
392
  mem: sint = 0
377
393
  mults: sint = 1
378
394
  mult_stack = []
395
+ dont_count: Set[UOp] = set()
396
+ if ignore_indexing:
397
+ for u in self.uops:
398
+ if u.op is UOps.LOAD:
399
+ dont_count = dont_count.union(u.src[1].sparents)
400
+ if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
401
+ elif u.op is UOps.STORE:
402
+ dont_count = dont_count.union(u.src[1].sparents)
403
+ if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
379
404
  for u in self.uops:
380
- if u.uop is UOps.RANGE:
405
+ if u.op is UOps.RANGE:
381
406
  mult_stack.append(mults)
382
- mults *= uop_alu_resolve(u.vin[1])
383
- elif u.uop is UOps.ENDRANGE:
407
+ mults *= uop_alu_resolve(u.src[1])
408
+ elif u.op is UOps.ENDRANGE:
384
409
  mults = mult_stack.pop(-1)
385
- elif u.uop is UOps.ALU:
386
- flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
387
- elif u.uop is UOps.LOAD:
410
+ elif u.op is UOps.LOAD:
388
411
  assert u.dtype is not None
389
412
  mem += u.dtype.itemsize * mults
390
- elif u.uop is UOps.STORE:
391
- assert u.vin[2].dtype is not None
392
- mem += u.vin[2].dtype.itemsize * mults
393
- elif u.uop is UOps.WMMA:
413
+ elif u.op is UOps.STORE:
414
+ assert u.src[2].dtype is not None
415
+ mem += u.src[2].dtype.itemsize * mults
416
+ elif u.op is UOps.ALU and u not in dont_count:
417
+ flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
418
+ elif u.op is UOps.WMMA and u not in dont_count:
394
419
  assert u.arg[1] is not None
395
420
  flops += 2 * prod(u.arg[1]) // 32 * mults
396
421
  return flops, mem
397
422
 
398
423
  def type_verify(self):
399
424
  for u in self.uops:
400
- uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
401
- if uop in {UOps.CONST, UOps.DEFINE_ACC}:
402
- if uop is UOps.DEFINE_ACC: arg = arg[0]
425
+ uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
426
+ if uop in (UOps.CONST, UOps.DEFINE_ACC):
427
+ if uop is UOps.DEFINE_ACC:
428
+ assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
429
+ arg = src[0].arg
403
430
  assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
404
431
  if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
432
+ if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool
433
+ if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
405
434
  if uop is UOps.ALU:
406
435
  if arg in UnaryOps:
407
- assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}"
408
- elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ):
436
+ assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
437
+ elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
409
438
  assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
410
- assert vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
439
+ assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
440
+ elif arg is BinaryOps.IDIV:
441
+ assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
442
+ f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
443
+ assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
444
+ elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
445
+ # the distance to shift isn't typechecked
446
+ assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
411
447
  elif arg in BinaryOps:
412
- assert dtype == vin[0].dtype == vin[1].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=} != {vin[1].dtype=}"
448
+ assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
413
449
  elif arg == TernaryOps.WHERE:
414
- assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}"
415
- assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}"
450
+ assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
451
+ assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"