tinygrad-0.10.1-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +35 -37
  4. tinygrad/codegen/linearize.py +19 -10
  5. tinygrad/codegen/lowerer.py +31 -8
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +10 -0
  8. tinygrad/device.py +28 -11
  9. tinygrad/dtype.py +12 -3
  10. tinygrad/engine/jit.py +3 -2
  11. tinygrad/engine/multi.py +0 -1
  12. tinygrad/engine/realize.py +7 -4
  13. tinygrad/engine/schedule.py +227 -255
  14. tinygrad/engine/search.py +20 -27
  15. tinygrad/gradient.py +3 -0
  16. tinygrad/helpers.py +7 -4
  17. tinygrad/nn/state.py +2 -2
  18. tinygrad/ops.py +64 -329
  19. tinygrad/renderer/__init__.py +19 -3
  20. tinygrad/renderer/cstyle.py +39 -18
  21. tinygrad/renderer/llvmir.py +55 -18
  22. tinygrad/renderer/ptx.py +6 -2
  23. tinygrad/renderer/wgsl.py +20 -12
  24. tinygrad/runtime/autogen/libc.py +404 -71
  25. tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
  26. tinygrad/runtime/autogen/webgpu.py +6985 -0
  27. tinygrad/runtime/graph/metal.py +28 -29
  28. tinygrad/runtime/ops_amd.py +37 -34
  29. tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
  30. tinygrad/runtime/ops_disk.py +1 -1
  31. tinygrad/runtime/ops_dsp.py +59 -33
  32. tinygrad/runtime/ops_llvm.py +14 -12
  33. tinygrad/runtime/ops_metal.py +78 -62
  34. tinygrad/runtime/ops_nv.py +9 -6
  35. tinygrad/runtime/ops_python.py +5 -5
  36. tinygrad/runtime/ops_webgpu.py +200 -38
  37. tinygrad/runtime/support/am/amdev.py +23 -11
  38. tinygrad/runtime/support/am/ip.py +10 -10
  39. tinygrad/runtime/support/elf.py +2 -0
  40. tinygrad/runtime/support/hcq.py +7 -5
  41. tinygrad/runtime/support/llvm.py +8 -14
  42. tinygrad/shape/shapetracker.py +3 -2
  43. tinygrad/shape/view.py +2 -3
  44. tinygrad/spec.py +21 -20
  45. tinygrad/tensor.py +150 -90
  46. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  47. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  48. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  49. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  50. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  51. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  52. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  53. tinygrad/viz/index.html +544 -0
  54. tinygrad/viz/perfetto.html +178 -0
  55. tinygrad/viz/serve.py +205 -0
  56. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
  57. tinygrad-0.10.2.dist-info/RECORD +99 -0
  58. tinygrad/codegen/rewriter.py +0 -516
  59. tinygrad-0.10.1.dist-info/RECORD +0 -86
  60. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  61. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
  62. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/ops.py CHANGED
@@ -1,9 +1,8 @@
1
1
  from __future__ import annotations
2
- from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, Literal, get_args
2
+ from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, get_args
3
3
  import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
4
4
  from enum import auto, IntEnum, Enum
5
5
  from dataclasses import dataclass, field
6
- from collections import defaultdict
7
6
  from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
8
7
  from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
9
8
  from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
@@ -89,11 +88,12 @@ class MathTrait(SimpleMathTrait):
89
88
  def sin(self): return self.alu(Ops.SIN)
90
89
  def log2(self): return self.alu(Ops.LOG2)
91
90
  def exp2(self): return self.alu(Ops.EXP2)
91
+ def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
92
92
 
93
93
  # the order of these Ops controls the order of the toposort
94
94
  class Ops(FastEnum):
95
95
  # uops that aren't rendered
96
- SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto(); KERNEL = auto() # noqa: E702
96
+ NAME = auto(); SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702
97
97
 
98
98
  # TODO: empty continues to exist because of tensor
99
99
  EMPTY = auto()
@@ -117,7 +117,7 @@ class Ops(FastEnum):
117
117
  REDUCE_AXIS = auto()
118
118
 
119
119
  # helper ops
120
- GEP = auto(); VECTORIZE = auto() # noqa: E702
120
+ GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702
121
121
 
122
122
  # UnaryOps
123
123
  CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
@@ -133,7 +133,7 @@ class Ops(FastEnum):
133
133
 
134
134
  # BinaryOps
135
135
  ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
136
- SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
136
+ SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
137
137
 
138
138
  # TernaryOps
139
139
  WHERE = auto(); MULACC = auto() # noqa: E702
@@ -151,18 +151,19 @@ class Ops(FastEnum):
151
151
  # device
152
152
  DEVICE = auto()
153
153
  MULTI = auto()
154
+ CUSTOM = auto()
154
155
 
155
156
  class GroupOp:
156
157
  Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
157
158
  Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
158
- Ops.SUB, Ops.FDIV}
159
+ Ops.SUB, Ops.FDIV, Ops.POW}
159
160
  Ternary = {Ops.WHERE, Ops.MULACC}
160
161
  ALU = set.union(Unary, Binary, Ternary)
161
162
 
162
163
  Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
163
164
  Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
164
165
 
165
- Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
166
+ Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
166
167
  Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
167
168
 
168
169
  # BinaryOps that can be flipped
@@ -175,21 +176,21 @@ class GroupOp:
175
176
  Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
176
177
 
177
178
  # do not preserve f(0) = 0
178
- UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
179
+ UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
179
180
 
180
181
  All = set(Ops)
181
182
 
182
183
  # some BUFFER ops can be processed with only a view
183
- view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
184
+ view_supported_devices = {"LLVM", "CPU", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
184
185
 
185
186
  # https://en.wikipedia.org/wiki/Identity_element
186
187
  def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
187
188
 
188
- def can_pad(u:UOp, edges:dict[UOp, UOp], visisted:dict[UOp, None]) -> bool:
189
+ def can_pad(u:UOp, edges:dict[UOp, None], cache:dict[UOp, None]) -> bool:
189
190
  if u.op in GroupOp.UnsafePad: return False
190
- if (len(u.src) == 2 and u.src[0] in edges) or u in visisted: return True
191
- visisted[u] = None
192
- return all(can_pad(x.base, edges, visisted) for x in u.src)
191
+ if u in edges or u in cache: return True
192
+ cache[u] = None
193
+ return all(can_pad(x.base, edges, cache) for x in u.src)
193
194
 
194
195
  # With True as the default, this matches the old symbolic behavior
195
196
  def resolve(x:UOp|bool, default:bool=True):
@@ -289,6 +290,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
289
290
  return ShapeTracker.from_shape(
290
291
  tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
291
292
  if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
293
+ if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
292
294
  # these ops define a ShapeTracker from the arg
293
295
  if self.op is Ops.VIEW: return self.arg
294
296
  if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
@@ -314,11 +316,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
314
316
  @property
315
317
  def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
316
318
  @property
317
- def size(self) -> int: return self.arg[1] if self.op is Ops.BUFFER else unwrap(self.st).size
319
+ def size(self) -> int: return self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
318
320
 
319
321
  # *** uop evaluation ***
320
322
 
321
323
  def simplify(self):
324
+ # late import!
325
+ from tinygrad.codegen.symbolic import symbolic
322
326
  with Context(TRACK_MATCH_STATS=0):
323
327
  return graph_rewrite(self, symbolic)
324
328
  def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
@@ -342,13 +346,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
342
346
  assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
343
347
  return unwrap(self.st)
344
348
  @property
345
- def const_arg(self) -> ConstType:
346
- match self.base.op:
347
- case Ops.CONST: ret = self.base.arg
348
- case op: raise AssertionError(f"const_arg called on {op}")
349
- assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
350
- return ret
351
- @property
352
349
  def axis_arg(self) -> tuple[int, ...]:
353
350
  assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
354
351
  ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
@@ -366,8 +363,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
366
363
  assert self.dtype.count == 1
367
364
  if count == 1: return self
368
365
  return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
369
- def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
370
- def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
366
+ def cast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.CAST, dtype, (self,))
367
+ def bitcast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
371
368
  def gep(self, i:Union[tuple[int, ...], int]):
372
369
  if isinstance(i, int):
373
370
  # NOTE: these are just shortcuts to not have to create and fold later
@@ -489,8 +486,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
489
486
  if op is Ops.BIND:
490
487
  var, val = arg.unbind()
491
488
  return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
492
- # otherwise it's just a VIEW(BUFFER)
493
- return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
489
+ # otherwise it's just a RESHAPE(BUFFER)
490
+ if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
491
+ return UOp.new_buffer(device, size, dtype).reshape(shape)
494
492
  def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
495
493
  # if it's a shrink, do the shrink before the copy with CONTIGUOUS
496
494
  if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
@@ -505,14 +503,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
505
503
  return ret
506
504
  def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
507
505
  @property
508
- def metadata(self): return all_metadata.get(self, None)
506
+ def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
509
507
 
510
508
  # *** uop movement ops ***
511
509
 
512
510
  @property
513
511
  def base(self) -> UOp:
514
- if self.op in GroupOp.Movement: return self.src[0].base
515
- return self.src[0].base if self.op is Ops.VIEW and len(self.src) == 1 else self
512
+ if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
513
+ return self
516
514
  def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
517
515
 
518
516
  def _mop(self, op:Ops, arg):
@@ -527,11 +525,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
527
525
  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
528
526
  def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)
529
527
 
528
+ # *** uop UNIQUE ***
529
+
530
+ # TODO: use this in Buffer
531
+ unique_num = itertools.count(0)
532
+ @staticmethod
533
+ def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))
534
+
530
535
  # *** uop Buffer stuff ***
531
536
 
532
- buffer_num = itertools.count(0)
533
537
  @staticmethod
534
- def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
538
+ def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device), UOp.unique()), size)
535
539
  @property
536
540
  def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
537
541
  @functools.cached_property
@@ -542,11 +546,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
542
546
  @property
543
547
  def buf_uop(self) -> UOp:
544
548
  if self.base.op is Ops.BUFFER: return self.base
545
- assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW}, f"buf_uop called on {self.op}"
549
+ assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN}, f"buf_uop called on {self.op}"
546
550
  return self.src[0].buf_uop
547
551
  @property
548
552
  def buffer(self) -> Buffer:
549
- if self.op is Ops.VIEW:
553
+ if self is not self.base:
550
554
  assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
551
555
  return self.src[0].buffer
552
556
  assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
@@ -591,6 +595,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
591
595
 
592
596
  # *** uop symbolic stuff ***
593
597
 
598
+ def is_increasing(self:UOp) -> bool:
599
+ # is f a monotonically increasing function regards its input
600
+ if self.op in GroupOp.Irreducible: return True
601
+ if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
602
+ if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
603
+ return False # False if not sure
594
604
  def const_factor(self) -> int:
595
605
  """largest known int that divides self"""
596
606
  if self.op is Ops.CONST: return self.arg
@@ -598,7 +608,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
598
608
  if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
599
609
  if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
600
610
  return 1
601
- def divides(self, v) -> UOp|None:
611
+ def divides(self, v:int) -> UOp|None:
602
612
  if v==1: return self
603
613
  if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
604
614
  if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
@@ -642,7 +652,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
642
652
  if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
643
653
  if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
644
654
  if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
645
- # TODO: UOps.SPECIAL is UOps.DEFINE_VAR
655
+ # TODO: Ops.SPECIAL is Ops.DEFINE_VAR
646
656
  if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
647
657
  if self.op is Ops.CONST: return self.arg, self.arg
648
658
  if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
@@ -665,20 +675,26 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
665
675
 
666
676
  @dataclass(frozen=True)
667
677
  class KernelInfo:
678
+ name: str = "test" # name of the kernel
668
679
  local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
669
680
  upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
670
681
  dont_use_locals: bool = False # don't use local indexing
671
682
 
672
- # ***** ops in python *****
683
+ # ******** ops in python ********
673
684
 
674
685
  def safe_exp2(x):
675
686
  try: return 2 ** x
676
687
  except OverflowError: return math.inf
677
688
 
689
+ def safe_pow(x, y):
690
+ try: return math.nan if isinstance(p:=pow(x, y), complex) else p
691
+ except ZeroDivisionError: return math.inf
692
+ except ValueError: return math.inf if x > 0 else -math.inf
693
+
678
694
  python_alu: dict[Ops, Callable] = {
679
695
  Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
680
696
  Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
681
- Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
697
+ Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
682
698
  Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
683
699
  Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
684
700
  Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
@@ -703,7 +719,8 @@ def get_location() -> tuple[str, int]:
703
719
  frm = sys._getframe(1)
704
720
  # find the real frame in the file that has the UPat, TODO: is there a better way to do this?
705
721
  while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
706
- "lowerer.py", "cstyle.py", "linearize.py"}:
722
+ "symbolic.py", "expander.py", "lowerer.py", "cstyle.py",
723
+ "linearize.py"}:
707
724
  frm = frm.f_back
708
725
  return frm.f_code.co_filename, frm.f_lineno
709
726
  @functools.lru_cache(None)
@@ -840,7 +857,9 @@ match_stats:dict[UPat, list[Union[int, float]]] = dict()
840
857
  class TrackedGraphRewrite:
841
858
  loc: tuple[str, int] # location that called graph_rewrite
842
859
  sink: UOp # the sink input to graph_rewrite
860
+ bottom_up: bool
843
861
  matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
862
+ name: Optional[str] = None
844
863
  tracked_keys:list[Any] = []
845
864
  tracked_ctxs:list[list[TrackedGraphRewrite]] = []
846
865
  _name_cnt:dict[str, int] = {}
@@ -923,304 +942,19 @@ class RewriteContext:
923
942
  self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
924
943
  return ret
925
944
 
926
- def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
927
- if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
928
- tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
945
+ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
946
+ if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
947
+ tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
929
948
  return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
930
949
 
931
- def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
932
- if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
933
- tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
950
+ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
951
+ if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
952
+ tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
934
953
  rewrite_ctx = RewriteContext(pm, ctx)
935
954
  return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
936
955
 
937
-
938
- # *** most of symbolic lives here now ***
939
-
940
- def split_uop(x:UOp, sep:Ops):
941
- if x.op is sep:
942
- for s in x.src: yield from split_uop(s, sep)
943
- else: yield x
944
-
945
- def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
946
- # simplify x // y or x % y, None means no change
947
- # simple cancel div/mod case
948
- if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
949
- return x - q*y if which is Ops.MOD else x.const_like(q)
950
-
951
- if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
952
-
953
- svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
954
- for u in split_uop(x, Ops.ADD):
955
- if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0:
956
- u = u.src[0]
957
- something_changed = True
958
- v: UOp = u.divides(f:=u.const_factor())
959
- q, r = divmod(f, c)
960
- if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True
961
- offset += r*v.vmin
962
- if u.op is Ops.CONST: const += f
963
- else: # div is the smallest common divisor of all terms
964
- if f > 1 and c % f == 0 and (div == 1 or div > f): div = f
965
- gcd = math.gcd(r, gcd)
966
- factors.append(f); svars.append(v); quotients.append(q); remainders.append(r) # noqa: E702
967
-
968
- lbound = ubound = offset = offset % c
969
- # we can fold if the expression has only one non-constant term and this term can only take on two values
970
- if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1:
971
- r = (offset+remainders[0])%c - offset%c
972
- offset -= r * v.vmin
973
- if which is Ops.MOD: return r*v + offset
974
- return (factors[0]-r)//c * v + (const-offset)//c
975
-
976
- # a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
977
- # within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
978
- for (r, v) in zip(remainders, svars):
979
- if r > c//2:
980
- if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break
981
- elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break
982
- offset -= r * v.vmin # determine what the new offset would be
983
- else: # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div
984
- remainders = [min(r, r-c, key=abs) for r in remainders]
985
- if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset))
986
- return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c))
987
-
988
- if gcd != 1: something_changed = True
989
- if not something_changed:
990
- if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div)
991
- return None
992
- quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
993
- for q,r,f,v in zip(quotients, remainders, factors, svars):
994
- if which is Ops.IDIV and (not split_rem) and r!=0:
995
- rem += f//gcd * v
996
- else:
997
- rem += r//gcd * v
998
- quo += q * v
999
-
1000
- if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
1001
- return rem//(c//gcd)+quo
1002
-
1003
- def lt_folding(x:UOp, c:int) -> UOp|None:
1004
- p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
1005
- if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
1006
- return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
1007
- return None
1008
-
1009
- def fold_unrolled_divs(divs:UOp):
1010
- # div pattern in unrolled arange
1011
- # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
1012
- add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
1013
- for u in add_chain:
1014
- if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
1015
- if denominator is None: denominator = u.src[1].arg
1016
- if denominator != u.src[1].arg: return None
1017
- # assumed CONST is the last of an ADD
1018
- if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
1019
- seen_const.append(s0.src[1].arg)
1020
- s0 = s0.src[0]
1021
- else: seen_const.append(0)
1022
- if ans is None: ans = s0
1023
- if ans is not s0: return None
1024
- if denominator is None: return None
1025
- # the first (denominator-len(seen_const)) terms may have been folded to 0 already
1026
- for i in range(denominator-len(seen_const)):
1027
- if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
1028
- return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
1029
-
1030
- def canonicalize_simplex(X:UOp) -> UOp|None:
1031
- # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
1032
- # returns x0 + x1 + ... in such case, or None if not
1033
- changed, ret = False, []
1034
- for u in split_uop(X, Ops.ADD):
1035
- # assumed the const is the last src of MUL
1036
- if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
1037
- changed = True
1038
- u = u.src[0]
1039
- if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
1040
- ret.append(u)
1041
- return functools.reduce(operator.add, ret) if changed else None
1042
-
1043
- def is_increasing(f:UOp) -> bool:
1044
- # is f a monotonically increasing function regards its input
1045
- if f.op in GroupOp.Irreducible: return True
1046
- if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
1047
- if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
1048
- return False # False if not sure
1049
-
1050
- def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
1051
- # if it's X <= c, returns X, True, c
1052
- # if it's X >= c, returns X, False, c
1053
-
1054
- # (X < c).ne(True) -> X >= c
1055
- if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
1056
- (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
1057
- # X < c -> X <= c-1
1058
- if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, valid.src[1].arg-1
1059
- raise ValueError(f"not able to parse {valid=}")
1060
-
1061
- def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
1062
- # return None if valid is always False, otherwise the simplified uop (might be the same as input)
1063
-
1064
- # first, parse valid into {expr: (lower_bound, upper_bound)}
1065
- bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
1066
- for stmt in split_uop(valid, Ops.AND):
1067
- try: expr, is_upper, c = parse_valid(stmt)
1068
- except ValueError: return uop # give up if we cannot parse the valid
1069
- bounds[expr][int(is_upper)] = c
1070
-
1071
- # simplify uop given that valid is True
1072
- for expr,v in bounds.items():
1073
- # some expr has lower bound > upper bound -> valid is an empty set and we return None
1074
- if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
1075
-
1076
- # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
1077
- candidates = []
1078
- if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
1079
- # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
1080
- candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
1081
- # try checking the whole clause
1082
- if expr in uop.toposort:
1083
- candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
1084
-
1085
- for candidate in candidates:
1086
- # if every branch in candidate gives the same simplified uop, we can rewrite the uop
1087
- newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
1088
- if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
1089
- if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
1090
- if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
1091
- elif all_same(newuops): uop = newuops[0]
1092
-
1093
- return uop
1094
-
1095
- def _valid_priority(v: UOp, valids:list[UOp]):
1096
- # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
1097
- try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
1098
- except ValueError: return 0
1099
-
1100
- def simplify_valid(valid:UOp) -> UOp|None:
1101
- ret:list[UOp] = []
1102
- something_changed = False
1103
- valids = list(split_uop(valid, Ops.AND))
1104
- for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
1105
- ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
1106
- if ret[-1] is not stmt: something_changed = True
1107
- return functools.reduce(operator.and_, ret) if something_changed else None
1108
-
1109
- # def max_var_const(x:UOp, c1:UOp, c2:UOp):
1110
- # if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
1111
- # if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
1112
-
1113
956
  def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
1114
957
 
1115
- symbolic_simple = PatternMatcher([
1116
- # ** self folding **
1117
- (UPat.var("x") + 0, lambda x: x), # x+0 -> x
1118
- (UPat.var("x") * 1, lambda x: x), # x*1 -> x
1119
- (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
1120
- (UPat.var("x") // 1, lambda x: x), # x//1 -> x
1121
- (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
1122
- (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
1123
- ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
1124
- ((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
1125
- (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
1126
- ((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
1127
- lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
1128
- (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
1129
- (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
1130
- (UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
1131
- (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
1132
- (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
1133
- # ** zero folding **
1134
- (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
1135
- (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
1136
- lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
1137
- # x*0 -> 0 or 0*x -> 0
1138
- # if x is nan or inf it should render the nan value.
1139
- # NOTE: this can be wrong for loaded NaN
1140
- (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
1141
- # ** constant folding **
1142
- # TODO: add const folding for Ops.THREEFRY
1143
- (UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))),
1144
- lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False)) if a.op is not Ops.THREEFRY else None),
1145
- # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
1146
- (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
1147
- (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
1148
- (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
1149
- # *** cast ***
1150
- (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
1151
- (UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
1152
- ])
1153
-
1154
- symbolic = symbolic_simple+PatternMatcher([
1155
- # ** COMMUTATIVE flipping **
1156
- (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
1157
- # ** boolean algebra **
1158
- (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
1159
- # ** combine terms **
1160
- (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
1161
- ((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
1162
- (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
1163
- ((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
1164
- (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
1165
- ((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
1166
- ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
1167
- (-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
1168
- # a conditional with the same results either way is a noop, also fold const conditionals
1169
- (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
1170
- (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
1171
- # alu of two where with same conds can combine, only do if true branch or false branch is const
1172
- (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
1173
- lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
1174
- # ALU min==max -> CONST (slow!)
1175
- (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
1176
- # max folding
1177
- (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
1178
- # TODO: why does this rule break beautiful_mnist?
1179
- #((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
1180
- #((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
1181
- # ** two stage ALU folding **
1182
- *((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
1183
- lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
1184
- ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
1185
- ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
1186
- # ** lt **
1187
- # c0*x<c1 for positive int c0,c1
1188
- ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
1189
- lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
1190
- # c0*x<c1 for negative int c0 and non-positive c1
1191
- ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
1192
- lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
1193
- # x//c0<c1 for positive int c0
1194
- ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False))<UPat.cvar("c1", vec=False),
1195
- lambda x,c0,c1: x<(c1.arg*c0.arg) if c0.arg > 0 else None),
1196
- # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
1197
- (UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
1198
- (UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
1199
- # *** rules from symbolic ***
1200
- # unrolled arange div folding
1201
- (UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
1202
- # generic lt folding
1203
- (UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
1204
- # canonicalize a simplex with positive coefficients > 0
1205
- # not x < 1 -> X > 0
1206
- ((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
1207
- # ** div **
1208
- # div folding
1209
- ((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d)
1210
- (UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
1211
- # ** mod **
1212
- # mod folding
1213
- (UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
1214
- ])
1215
-
1216
-
1217
- symbolic_flat = symbolic+PatternMatcher([
1218
- # ** combine terms (opinionated) **
1219
- (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
1220
- # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
1221
- ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
1222
- ])
1223
-
1224
958
  _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
1225
959
 
1226
960
  # for debug
@@ -1250,7 +984,8 @@ ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
1250
984
  merge_views = PatternMatcher([
1251
985
  # VIEW(VIEW) merges to a single VIEW
1252
986
  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
1253
- (UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None),
987
+ # remove VIEW if it's contiguous and same as the base shape
988
+ (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda vm,x: x if vm.st.contiguous and x.shape == vm.shape else None),
1254
989
  # merge unmasked const views
1255
990
  (UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
1256
991
  lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
@@ -1,11 +1,24 @@
1
1
  from __future__ import annotations
2
2
  from typing import Optional, Callable
3
3
  import functools, math
4
+ from enum import Enum, auto
4
5
  from dataclasses import dataclass, field, replace
5
6
  from tinygrad.helpers import to_function_name, dedup, prod
6
7
  from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
7
8
  from tinygrad.dtype import DType
8
9
 
10
+ class OptOps(Enum):
11
+ TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
12
+ GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
13
+ def __lt__(self, x:OptOps): return self.value < x.value
14
+
15
+ @dataclass(frozen=True, order=True)
16
+ class Opt:
17
+ op: OptOps
18
+ axis: Optional[int] = None
19
+ arg: Optional[int | tuple] = None
20
+ def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
21
+
9
22
  @dataclass(frozen=True)
10
23
  class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
11
24
  dims: tuple[int,int,int] # N, M, K
@@ -70,7 +83,9 @@ class ProgramSpec:
70
83
  name:str
71
84
  src:str
72
85
  device:str
86
+ ast:UOp # save the base ast (this is method cache key)
73
87
  uops:Optional[list[UOp]]=None
88
+ applied_opts:Optional[list[Opt]]=None
74
89
  mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
75
90
 
76
91
  # filled in from uops (if we have uops)
@@ -121,12 +136,13 @@ class Renderer:
121
136
  has_local: bool = True
122
137
  has_shared: bool = True
123
138
  # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
124
- global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
125
- local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
139
+ global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
140
+ local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
126
141
  shared_max: int = 32768
127
142
  tensor_cores: list[TensorCore] = []
143
+ pre_matcher: Optional[PatternMatcher] = None
128
144
  extra_matcher: Optional[PatternMatcher] = None
129
145
  code_for_op: dict[Ops, Callable] = {}
130
146
 
131
147
  def __reduce__(self): return self.__class__, ()
132
- def render(self, name:str, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
148
+ def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")