tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/codegen/uops.py CHANGED
@@ -1,24 +1,24 @@
1
1
  from __future__ import annotations
2
- from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
3
- import functools, itertools, heapq, math
2
+ from typing import Optional, Tuple, Any, Set, cast, List, Union, DefaultDict, Callable, Dict
3
+ import functools, itertools, math
4
4
  from collections import defaultdict
5
5
  from enum import Enum, auto
6
- from dataclasses import dataclass, field
6
+ from dataclasses import dataclass
7
7
  from tinygrad.dtype import ConstType, dtypes, DType
8
8
  from tinygrad.shape.symbolic import sint, Variable
9
9
  from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
10
- from tinygrad.helpers import prod, DEBUG, getenv
10
+ from tinygrad.helpers import merge_dicts, prod, pretty_print
11
11
 
12
12
  # the order of these UOps controls the order of the toposort
13
13
  class UOps(Enum):
14
14
  # ops that aren't rendered
15
- SINK = auto(); VAR = auto() # noqa: E702
15
+ SINK = auto(); EXPAND = auto(); CONTRACT = auto() # noqa: E702
16
16
  DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
17
17
  CONST = auto(); SPECIAL = auto() # noqa: E702
18
- NOOP = auto(); UNMUL = auto(); GEP = auto() # noqa: E702
18
+ NOOP = auto(); GEP = auto() # noqa: E702
19
19
  # math ops
20
- CAST = auto(); BITCAST = auto() # noqa: E702
21
- ALU = auto(); WMMA = auto() # noqa: E702
20
+ CAST = auto(); BITCAST = auto(); VECTORIZE = auto() # noqa: E702
21
+ ALU = auto(); REDUCE = auto(); WMMA = auto() # noqa: E702
22
22
  # memory/assignment ops
23
23
  LOAD = auto(); STORE = auto(); PHI = auto() # noqa: E702
24
24
  # control flow ops
@@ -26,7 +26,8 @@ class UOps(Enum):
26
26
  # these two are not graph nodes
27
27
  ENDRANGE = auto(); ENDIF = auto() # noqa: E702
28
28
 
29
- def ufix(dtype: Optional[DType], x): return UOp.const(dtype, x) if not isinstance(x, UOp) else x
29
+ END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
30
+
30
31
  @dataclass(frozen=True, eq=False)
31
32
  class UOp:
32
33
  op: UOps
@@ -34,418 +35,259 @@ class UOp:
34
35
  src: Tuple[UOp, ...] = tuple()
35
36
  arg: Any = None
36
37
  def commutative(self) -> bool:
37
- return self.op is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR}
38
+ return (self.op is UOps.ALU and \
39
+ self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
38
40
  @functools.cached_property
39
41
  def cmp_tuple(self):
40
42
  # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
41
43
  return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
42
- (type(self.op), self.op.value), self.dtype, self.src)
44
+ self.arg.value, self.dtype, self.src)
43
45
  def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
44
- def __repr__(self):
45
- return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
46
- def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
47
- def name(self, name:Optional[str]): return UOp(UOps.VAR, src=(self,), arg=name)
48
- def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
49
- def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x))
50
- def __radd__(self, x): return UOp.alu(BinaryOps.ADD, ufix(self.dtype, x), self)
51
- def __sub__(self, x): return UOp.alu(BinaryOps.ADD, self, -ufix(self.dtype, x))
52
- def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, ufix(self.dtype, x))
53
- def __rmul__(self, x): return UOp.alu(BinaryOps.MUL, ufix(self.dtype, x), self)
54
- def __floordiv__(self, x): return UOp.alu(BinaryOps.IDIV, self, ufix(self.dtype, x))
55
- def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
56
- def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
57
- def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
58
- def ge(self, x): return -self.lt(x)
59
- def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
60
- def min(self, x): return -UOp.alu(BinaryOps.MAX, -self, -x)
46
+ def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
47
+ # *** uop syntactic sugar
48
+ def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
49
+ def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
50
+ def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
51
+ def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar() if self.dtype is not None else None, (self,), i)
52
+ def __neg__(self): return self.alu(UnaryOps.NEG)
53
+ def __add__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
54
+ def __radd__(self, x): return self.alu(BinaryOps.ADD, self.ufix(x))
55
+ def __sub__(self, x): return self.alu(BinaryOps.ADD, self.ufix(-x))
56
+ def __mul__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x))
57
+ def __rmul__(self, x): return self.ufix(x).alu(BinaryOps.MUL, self)
58
+ def __floordiv__(self, x): return self.alu(BinaryOps.IDIV, self.ufix(x))
59
+ def __truediv__(self, x): return self.alu(BinaryOps.MUL, self.ufix(x).alu(UnaryOps.RECIP))
60
+ def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x))
61
+ def __xor__(self, x): return self.alu(BinaryOps.XOR, self.ufix(x))
62
+ def __and__(self, x): return self.alu(BinaryOps.AND, self.ufix(x))
63
+ def __or__(self, x): return self.alu(BinaryOps.OR, self.ufix(x))
64
+ def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x))
65
+ def eq(self, x): return -self.ne(x)
66
+ def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
67
+ def ge(self, x): return (-self).lt(-x+1)
68
+ def max(self, x): return self.alu(BinaryOps.MAX, x)
69
+ def min(self, x): return -(-self).max(-x)
70
+ def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
71
+ def recip(self): return self.alu(UnaryOps.RECIP)
72
+ def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
73
+ def sconst(self:Union[UOp, DType, None], b:ConstType|Variable):
74
+ return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b)
61
75
  @staticmethod
62
76
  @functools.lru_cache(maxsize=None)
63
- def const(dtype:Optional[DType], b:ConstType|Variable):
64
- if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
77
+ def _const(dtype:Optional[DType], b:ConstType|Variable):
78
+ # TODO: fix dtype of b.max after Variable is just an UOp
79
+ if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
80
+ if dtype is not None and dtype != (sdtype := dtype.scalar()):
81
+ return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
65
82
  return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
83
+ def alu(self, arg, *src:UOp):
84
+ return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
66
85
  @staticmethod
67
- def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
68
- @staticmethod
69
- def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
86
+ def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return type(src[0])(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values()))
70
87
  @staticmethod
71
- def store(*src:UOp, **kwargs): return UOp(UOps.STORE, None, tuple(src)+tuple(kwargs.values()))
72
- @staticmethod
73
- def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
74
- @staticmethod
75
- def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name)
88
+ def store(*src:UOp, **kwargs): return type((src:=(*src, *kwargs.values()))[0])(UOps.STORE, None, src)
76
89
  @functools.cached_property
77
- def parents(self) -> Set[UOp]: return set.union(set(self.src), *[x.parents for x in self.src])
90
+ def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src])
78
91
  @property # parents with self
79
- def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
80
- def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
81
-
82
- def uop_alu_resolve(u:UOp) -> sint:
83
- if u.op is UOps.CONST: return u.arg
84
- if u.op is UOps.DEFINE_VAR: return u.arg
85
- if u.op is UOps.SPECIAL: return u.arg[2]-1
86
- if u.op is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.src[0]) * uop_alu_resolve(u.src[1])
87
- if u.op is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.src[0]) * (2**cast(int, uop_alu_resolve(u.src[1])))
88
- if u.op is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.src[0]) + uop_alu_resolve(u.src[1])
89
- raise RuntimeError(f"ALU resolve fail @ {u.op}")
92
+ def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
93
+ def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
94
+ def const_factor(self) -> int:
95
+ """largest known int that divides self"""
96
+ if self.op is UOps.CONST: return self.arg
97
+ if self.op is UOps.ALU:
98
+ if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[0].const_factor())
99
+ if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
100
+ return 1
101
+ def divides(self, v) -> Optional[UOp]:
102
+ if v==1: return self
103
+ if self.op is UOps.CONST: return self.const(self.arg//v) if self.arg%v == 0 else None
104
+ if self.op is UOps.ALU:
105
+ if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
106
+ if self.arg is BinaryOps.MUL:
107
+ if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
108
+ if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
109
+ return None # generic None if we aren't sure
110
+ @functools.cached_property
111
+ def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype)))
112
+ @functools.cached_property
113
+ def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype)))
114
+ @functools.cached_property
115
+ def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
116
+ # NOTE: returned UOp is assumed to be CONST
117
+ if self.op is UOps.DEFINE_VAR: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None
118
+ if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None
119
+ # TODO: UOps.SPECIAL is UOps.DEFINE_VAR
120
+ if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
121
+ if self.op is UOps.CONST: return self, self
122
+ if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
123
+ s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
124
+ if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
125
+ return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg)
126
+ if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg)
127
+ if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
128
+ # handle at lease one is non-negative
129
+ Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
130
+ Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
131
+ assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
132
+ return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax)
133
+ if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1)
134
+ if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
135
+ if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
136
+ if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
137
+ if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
138
+ if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
139
+ (UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
140
+ return None, None
141
+
142
+ @dataclass(frozen=True, repr=False) # reuse repr from UOp
143
+ class NOp(UOp):
144
+ name:Optional[str] = None
145
+ src:Tuple[NOp, ...] = tuple()
146
+ allow_any_len:bool = False
147
+ @staticmethod
148
+ def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, name=name)
149
+ @staticmethod
150
+ def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.CONST, dtype=dtype, name=name)
151
+ def const(self:Union[UOp, DType, None], b:ConstType|Variable): return NOp((x:=UOp.const(self, b)).op, x.dtype, x.src, x.arg)
90
152
 
91
- # *** simplification logic ***
153
+ def compile(self: NOp, name:Optional[str]=None) -> UPat:
154
+ return UPat(name=self.name, dtype=self.dtype) if self.op is UOps.NOOP else UPat(self.op, self.arg, (list if self.commutative()
155
+ else tuple)(src.compile() for src in self.src) or None, self.name or name, self.dtype, self.allow_any_len)
92
156
 
93
- @dataclass(frozen=True)
94
157
  class UPat:
95
- op: Optional[Union[UOps, Set[UOps]]] = None
96
- arg: Any = None
97
- src: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
98
- name: Optional[str] = None
99
- dtype: Optional[Union[DType, Set[DType]]] = None
100
- allow_len: Set[int] = field(default_factory=set)
101
-
102
- @staticmethod
103
- def compile(u: UOp, name:Optional[str]=None) -> UPat:
104
- if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
105
- return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
106
-
107
- T = TypeVar("T")
108
- def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1
158
+ def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,
159
+ name:Optional[str]=None, dtype:Optional[Union[DType, Set[DType]]]=None, allow_any_len:bool=False):
160
+ self.op: Optional[Tuple[UOps, ...]] = None if op is None else (tuple(op) if isinstance(op, set) else (op,))
161
+ self.dtype: Optional[Tuple[DType, ...]] = None if dtype is None else (tuple(dtype) if isinstance(dtype, set) else (dtype,))
162
+ self.arg, self.name = arg, name
163
+ self.src: Any = None
164
+ # try all permutations if it's a list
165
+ if isinstance(src, list): self.src = list(itertools.permutations(src))
166
+ # only one if it's a tuple
167
+ elif isinstance(src, tuple): self.src = [src]
168
+ # repeat if it's a UPat
169
+ elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
170
+
171
+ self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
109
172
 
110
- def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
111
- if pat.name is not None and store.setdefault(pat.name, uop) is not uop: return False
112
- if pat.arg is not None and __unmatch(pat.arg, uop.arg): return False
113
- if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return False
114
- if pat.op is not None and __unmatch(pat.op, uop.op): return False
115
- if pat.src is None: return True
116
- # only one if it's a tuple
117
- # try all permutations if it's a list
118
- # repeat if it's a UPat
119
- for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
120
- if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
121
- new_store = store.copy()
122
- if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
123
- store.update(new_store)
124
- return True
125
- return False
173
+ def __repr__(self):
174
+ def rep(x):
175
+ form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
176
+ return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
177
+ set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
178
+ return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
179
+
180
+ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:
181
+ if (pat.name is not None and store.setdefault(pat.name, uop) is not uop) or \
182
+ (pat.dtype is not None and uop.dtype not in pat.dtype) or \
183
+ (pat.arg is not None and pat.arg != uop.arg) or \
184
+ (pat.op is not None and uop.op not in pat.op): return []
185
+ if pat.src is None: return [store]
186
+ res: List[Dict[str, UOp]] = []
187
+ for vp in pat.src:
188
+ if pat.allowed_len != 0 and len(uop.src) != pat.allowed_len: return []
189
+ new_stores = [store.copy()]
190
+ for uu, vv in zip(uop.src, vp): new_stores = [rstore for nstore in new_stores for rstore in _match(uu, vv, nstore)]
191
+ res.extend(new_stores)
192
+ return res
126
193
 
127
194
  class PatternMatcher:
128
- def __init__(self, patterns:List[Tuple[Union[UPat, UOp], Callable]]):
195
+ def __init__(self, patterns:List[Tuple[Union[UPat, NOp], Callable]]):
129
196
  self.patterns = patterns
130
197
  self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable]]] = defaultdict(list)
131
198
  # uop is required, arg is optional
132
199
  for p,fxn in self.patterns:
133
- if isinstance(p, UOp): p = UPat.compile(p)
200
+ if isinstance(p, NOp): p = p.compile()
134
201
  assert p.op is not None
135
- if isinstance(p.op, set):
136
- for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
137
- else:
138
- self.pdict[(p.op, p.arg)].append((p, fxn))
202
+ for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
203
+
204
+ @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
205
+ def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
139
206
 
140
207
  def rewrite(self, uop:UOp) -> Optional[UOp]:
141
208
  for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
142
- store: Dict[str, UOp] = {}
143
- if _match(uop, p, store): return fxn(**store)
209
+ if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
144
210
  return None
145
211
 
146
- def sum_collapse(phi_input, loop, val1, val2):
147
- for v1,v2 in [(val1, val2), (val2, val1)]:
148
- if loop not in v1.parents:
149
- loop_range = loop.src[1]-loop.src[0]
150
- ret = v1*loop_range.cast(v1.dtype)
151
- return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+ret
152
- return None
153
-
154
- def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng):
155
- if not rng.arg[2]: return None # must be a reduce
156
- if mval.arg >= 0 or loop_start.arg != 0:
157
- # TODO: support and test this with other mvals and loop_starts
158
- if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
159
- return None
160
- comprange = UOp.min(loop_end, UOp.max(UOp.alu(BinaryOps.IDIV, idx-compval-mval, mval) + (loop_end-loop_start), loop_start))
161
- return UOp(UOps.UNMUL, multconst.dtype, (comprange.cast(multconst.dtype) * multconst, loop_end-loop_start))
162
-
163
- # this is symbolic 2.0
164
- constant_folder = PatternMatcher([
165
- # arange loop folding (early)
166
- (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
167
- UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, src=[UPat(UOps.CONST, name="mval"),
168
- UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
169
- UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
170
- (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
171
- UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, UnaryOps.NEG, src=[
172
- UPat(UOps.RANGE, name="rng", src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
173
- UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))),
174
- lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
175
- # sum collapse to mul (with possible GEP)
176
- (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
177
- UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
178
- (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
179
- UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
180
- # deal with UNMUL
181
- (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
182
- lambda c1,c2,v: v if c1.arg == c2.arg else None),
183
- (UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
184
- (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
185
- # max on special can go away (TODO: special should be variable, same thing applies)
186
- (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
187
- # const rules
188
- (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
189
- (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
190
- # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
191
- (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
192
- (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
193
- (UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
194
- # a DEFINE_ACC without inputs is a const + GEP on a const is the const
195
- (UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
196
- (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
197
- # max -2147483648
198
- (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
199
- # bool < False is always false, True < bool is always false
200
- (UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
201
- (UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
202
- # a conditional with the same results either way is a noop, also fold const conditionals
203
- (UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val),
204
- (UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
205
- # ** constant folding **
206
- (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
207
- # ** self folding **
208
- (-(-UOp.var('x')), lambda x: x), # -(-x) -> x
209
- (UOp.var('x') + 0, lambda x: x), # x+0 -> x
210
- (UOp.var('x') - 0, lambda x: x), # x-0 -> x
211
- (UOp.var('x') * 1, lambda x: x), # x*1 -> x
212
- (UOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
213
- (UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
214
- (UOp.var('x') // 1, lambda x: x), # x//1 -> x
215
- (UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
216
- (UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
217
- (UOp.var('x') / UOp.cvar('c'), lambda x,c: x*exec_alu(UnaryOps.RECIP, c.dtype, [c.arg])), # x/c -> x*(1/c)
218
- (UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
219
- # ** zero folding **
220
- #x*0 -> 0 or 0*x -> 0
221
- #if x is nan or inf it should render the nan value.
222
- # NOTE: this can be wrong for loaded NaN
223
- (UOp.var('x') * 0, lambda x: UOp.const(x.dtype, float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
224
- (UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
225
- # ** load/store folding **
226
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
227
- # ** two stage add/sub folding **
228
- ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
229
- ((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
230
- # *** rules from symbolic ***
231
- # two stage mul, (x*c1)*c2 = x*(c1*c2)
232
- ((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
233
- # x%1 -> 0
234
- (UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
235
- # (x*c0)+(x*c1) -> x*(c0+c1)
236
- (UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
237
- # (x*c0)+(y*c0) -> (x+y)*c0
238
- #((UOp.var("x") * UOp.cvar("c0")) + (UOp.var("y") * UOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
239
- # (x*c0)//c0 -> x
240
- ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
241
- # (x*x2)/x2 -> x
242
- ((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
243
- # (x//c0)//c1 -> x//(c0*c1)
244
- ((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
245
- # (x/x1)/x2 -> x/(x1*x2)
246
- ((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
247
- # c0 + x < c1 -> x < c1 - c0
248
- ((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
249
- lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
250
- # (x+x*c0)-> x*(c0+1)
251
- (UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
252
- # TODO: can do the invert of this (flip alt/load) when we fix double ops
253
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
254
- lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
255
- # store float4/float2 directly (remove CAST/GEP)
256
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
257
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
258
- # CAST-PHI-GEP -> PHI-CAST
259
- (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
260
- lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
261
- (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
262
- lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
263
- # NEG/CMPLT -> CMPLT
264
- (UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
265
- # cast NOOP (NOTE: it's str to deal with PtrDType)
266
- (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
267
- # fold gated LOAD/STORE
268
- (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
269
- (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var"), UOp.var("barrier")),
270
- lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
271
- (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var")), lambda var: var),
272
- (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var"), UOp.var()), lambda var: var),
273
- (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(None, 1)), UOp.store),
274
- (UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(None, 0)), lambda: UOp(UOps.NOOP)),
275
- # remove NOOPs from SINK
276
- (UPat(UOps.SINK, name="root"),
277
- lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None)
278
- ])
279
-
280
- # *** uop graph ***
212
+ def type_verify(uops):
213
+ for u in uops:
214
+ uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
215
+ if uop in {UOps.CONST, UOps.DEFINE_ACC}:
216
+ if uop is UOps.CONST:
217
+ assert dtype is not None and dtype == dtype.scalar(), f"consts should be scalar, got {dtype}"
218
+ assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
219
+ if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
220
+ if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
221
+ if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
222
+ if uop is UOps.VECTORIZE:
223
+ assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
224
+ assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
225
+ if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
226
+ if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
227
+ if uop is UOps.STORE:
228
+ assert dtype is None, f"{uop} dtype must be None, got {dtype}"
229
+ if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
230
+ if uop is UOps.ALU:
231
+ if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
232
+ elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
233
+ assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
234
+ assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
235
+ elif arg is BinaryOps.IDIV:
236
+ assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
237
+ assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}"
238
+ elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
239
+ # the distance to shift isn't typechecked
240
+ assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
241
+ elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
242
+ elif arg == TernaryOps.WHERE:
243
+ assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
244
+ f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
245
+ assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
281
246
 
282
- def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
283
- if u in children: return
284
- children[u] = []
285
- for x in u.src:
286
- get_children_dfs(x, children, in_degree)
287
- children[x].append(u)
288
- in_degree[u] = len(u.src)
289
-
290
- def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
291
- nodes: Dict[Tuple, UOp] = {}
292
- replace: Dict[UOp, UOp] = {}
293
- def __inner_rewrite(n:UOp) -> UOp:
294
- if n in replace: return replace[n]
295
- replace_source = (n.op, n.dtype, tuple(__inner_rewrite(y) for y in n.src), n.arg)
296
- if found := nodes.get(replace_source): replace[n] = found
297
- else: nodes[replace_source] = replace[n] = __inner_rewrite(new_x) if (new_x := pm.rewrite(x:=UOp(*replace_source))) else x
298
- return replace[n]
299
- return __inner_rewrite(sink)
300
-
301
- class UOpGraph:
302
- def __init__(self, sinks:List[UOp]):
303
- self.sinks: List[UOp] = sinks
304
- # used by linearizer
305
- self._uops: Optional[List[UOp]] = None
306
-
307
- def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
308
- def __getitem__(self, index) -> UOp: return self.uops[index]
309
-
310
- def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr)
311
- def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]
312
-
313
- @property
314
- def uops(self):
315
- if self._uops is None: self.linearize()
316
- return self._uops
317
-
318
- def graph(self):
319
- from tinygrad.engine.graph import graph_uops
320
- graph_uops(self.uops)
321
-
322
- def print(self):
323
- for i,u in enumerate(self):
324
- print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
325
-
326
- def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
327
- # NOTE: relinearizering should be okay
328
- #assert self._uops is None, "already linearized"
329
- sink = UOp(UOps.SINK, None, tuple(self.sinks))
330
-
331
- # dedup all nodes and do graph rewrite
332
- sink = graph_rewrite(sink, constant_folder)
333
- if extra_pm: sink = graph_rewrite(sink, PatternMatcher(constant_folder.patterns+extra_pm.patterns))
334
-
335
- # filter nodes that don't link to a sink
336
- # BFS toposort
337
- children: Dict[UOp, List[UOp]] = {}
338
- in_degree: Dict[UOp, int] = {}
339
- get_children_dfs(sink, children, in_degree)
340
-
341
- @functools.lru_cache(None)
342
- def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
343
- if x.op is UOps.SINK: return set()
344
- return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
345
-
346
- # scope children impact the toposort and END* insertion
347
- end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
348
- loops, ifs = [x for x in in_degree if x.op is UOps.RANGE], [x for x in in_degree if x.op is UOps.IF]
349
- scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]}
350
-
351
- queue:List[Tuple[int, UOp]] = []
352
- def push(u:UOp):
353
- priority = 0
354
- # prefer uops that are loop children
355
- for l, ss in scope_children.items():
356
- if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
357
- heapq.heappush(queue, (priority, u))
358
-
359
- for u in children:
360
- if in_degree[u] == 0: push(u)
361
-
362
- if getenv("FUZZ_UOPS", 0):
363
- from test.external.fuzz_uops import fuzz_uops
364
- self.fuzz_paths = fuzz_uops(children, in_degree.copy(), scope_children)
365
-
366
- self._uops = []
367
- while queue:
368
- p,x = heapq.heappop(queue)
369
- if DEBUG >= 7: print(p,x)
370
- if x.op is UOps.DEFINE_ACC and len(x.src) > 1:
371
- idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
372
- self._uops.insert(idx, x)
373
- else:
374
- self._uops.append(x)
375
- for u, ss in scope_children.items():
376
- if x in ss:
377
- ss.remove(x)
378
- if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.op][1], None, (u,)))
379
- for u in children[x]:
380
- in_degree[u] -= 1
381
- if in_degree[u] == 0: push(u)
382
-
383
- assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
384
- self._uops = self._uops[:-1]
385
-
386
- if type_verify: self.type_verify()
387
-
388
- # *** checker functions ***
247
+ def uop_alu_resolve(u:UOp) -> sint:
248
+ if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
249
+ if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
250
+ raise RuntimeError(f"ALU resolve fail @ {u.op}")
389
251
 
390
- def flops_mem(self, ignore_indexing=False) -> Tuple[sint, sint]:
391
- flops: sint = 0
392
- mem: sint = 0
393
- mults: sint = 1
394
- mult_stack = []
395
- dont_count: Set[UOp] = set()
396
- if ignore_indexing:
397
- for u in self.uops:
398
- if u.op is UOps.LOAD:
399
- dont_count = dont_count.union(u.src[1].sparents)
400
- if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
401
- elif u.op is UOps.STORE:
402
- dont_count = dont_count.union(u.src[1].sparents)
403
- if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
404
- for u in self.uops:
405
- if u.op is UOps.RANGE:
406
- mult_stack.append(mults)
407
- mults *= uop_alu_resolve(u.src[1])
408
- elif u.op is UOps.ENDRANGE:
409
- mults = mult_stack.pop(-1)
410
- elif u.op is UOps.LOAD:
411
- assert u.dtype is not None
412
- mem += u.dtype.itemsize * mults
252
+ def print_uops(uops:List[UOp]):
253
+ for i,u in enumerate(uops):
254
+ formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
255
+ print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")
256
+
257
+ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
258
+ flops: sint = 0
259
+ mem: sint = 0
260
+ mults: sint = 1
261
+ mult_stack: List[sint] = []
262
+ dont_count: Set[UOp] = set()
263
+ if ignore_indexing:
264
+ for u in uops:
265
+ if u.op is UOps.LOAD:
266
+ dont_count = dont_count.union(u.src[1].sparents)
267
+ if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
413
268
  elif u.op is UOps.STORE:
414
- assert u.src[2].dtype is not None
415
- mem += u.src[2].dtype.itemsize * mults
416
- elif u.op is UOps.ALU and u not in dont_count:
417
- flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
418
- elif u.op is UOps.WMMA and u not in dont_count:
419
- assert u.arg[1] is not None
420
- flops += 2 * prod(u.arg[1]) // 32 * mults
421
- return flops, mem
422
-
423
- def type_verify(self):
424
- for u in self.uops:
425
- uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
426
- if uop in (UOps.CONST, UOps.DEFINE_ACC):
427
- if uop is UOps.DEFINE_ACC:
428
- assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
429
- arg = src[0].arg
430
- assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
431
- if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
432
- if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool
433
- if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
434
- if uop is UOps.ALU:
435
- if arg in UnaryOps:
436
- assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
437
- elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
438
- assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
439
- assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
440
- elif arg is BinaryOps.IDIV:
441
- assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
442
- f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
443
- assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
444
- elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
445
- # the distance to shift isn't typechecked
446
- assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
447
- elif arg in BinaryOps:
448
- assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
449
- elif arg == TernaryOps.WHERE:
450
- assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
451
- assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
269
+ dont_count = dont_count.union(u.src[1].sparents)
270
+ if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
271
+ elif u.op is UOps.IF:
272
+ dont_count = dont_count.union(u.src[0].sparents)
273
+ for u in uops:
274
+ if u.op is UOps.RANGE:
275
+ mult_stack.append(mults)
276
+ mults *= uop_alu_resolve(u.src[1] - u.src[0])
277
+ elif u.op is UOps.ENDRANGE:
278
+ mults = mult_stack.pop(-1)
279
+ elif u.op is UOps.SPECIAL:
280
+ mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
281
+ elif u.op is UOps.LOAD:
282
+ assert u.dtype is not None
283
+ mem += u.dtype.itemsize * mults
284
+ elif u.op is UOps.STORE:
285
+ assert u.src[2].dtype is not None
286
+ mem += u.src[2].dtype.itemsize * mults
287
+ elif u.op is UOps.ALU and u not in dont_count:
288
+ assert u.dtype is not None
289
+ flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
290
+ elif u.op is UOps.WMMA and u not in dont_count:
291
+ assert u.arg[1] is not None
292
+ flops += 2 * prod(u.arg[1]) // 32 * mults
293
+ return flops, mem