tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +35 -37
- tinygrad/codegen/linearize.py +19 -10
- tinygrad/codegen/lowerer.py +31 -8
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +10 -0
- tinygrad/device.py +28 -11
- tinygrad/dtype.py +12 -3
- tinygrad/engine/jit.py +3 -2
- tinygrad/engine/multi.py +0 -1
- tinygrad/engine/realize.py +7 -4
- tinygrad/engine/schedule.py +227 -255
- tinygrad/engine/search.py +20 -27
- tinygrad/gradient.py +3 -0
- tinygrad/helpers.py +7 -4
- tinygrad/nn/state.py +2 -2
- tinygrad/ops.py +64 -329
- tinygrad/renderer/__init__.py +19 -3
- tinygrad/renderer/cstyle.py +39 -18
- tinygrad/renderer/llvmir.py +55 -18
- tinygrad/renderer/ptx.py +6 -2
- tinygrad/renderer/wgsl.py +20 -12
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/metal.py +28 -29
- tinygrad/runtime/ops_amd.py +37 -34
- tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
- tinygrad/runtime/ops_disk.py +1 -1
- tinygrad/runtime/ops_dsp.py +59 -33
- tinygrad/runtime/ops_llvm.py +14 -12
- tinygrad/runtime/ops_metal.py +78 -62
- tinygrad/runtime/ops_nv.py +9 -6
- tinygrad/runtime/ops_python.py +5 -5
- tinygrad/runtime/ops_webgpu.py +200 -38
- tinygrad/runtime/support/am/amdev.py +23 -11
- tinygrad/runtime/support/am/ip.py +10 -10
- tinygrad/runtime/support/elf.py +2 -0
- tinygrad/runtime/support/hcq.py +7 -5
- tinygrad/runtime/support/llvm.py +8 -14
- tinygrad/shape/shapetracker.py +3 -2
- tinygrad/shape/view.py +2 -3
- tinygrad/spec.py +21 -20
- tinygrad/tensor.py +150 -90
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- tinygrad/codegen/rewriter.py +0 -516
- tinygrad-0.10.1.dist-info/RECORD +0 -86
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/ops.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type,
|
2
|
+
from typing import Any, Optional, Union, Callable, cast, TYPE_CHECKING, Type, get_args
|
3
3
|
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
|
4
4
|
from enum import auto, IntEnum, Enum
|
5
5
|
from dataclasses import dataclass, field
|
6
|
-
from collections import defaultdict
|
7
6
|
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
|
8
7
|
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
|
9
8
|
from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
|
@@ -89,11 +88,12 @@ class MathTrait(SimpleMathTrait):
|
|
89
88
|
def sin(self): return self.alu(Ops.SIN)
|
90
89
|
def log2(self): return self.alu(Ops.LOG2)
|
91
90
|
def exp2(self): return self.alu(Ops.EXP2)
|
91
|
+
def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
|
92
92
|
|
93
93
|
# the order of these Ops controls the order of the toposort
|
94
94
|
class Ops(FastEnum):
|
95
95
|
# uops that aren't rendered
|
96
|
-
SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto();
|
96
|
+
NAME = auto(); SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); KERNEL = auto(); UNIQUE = auto() # noqa: E702
|
97
97
|
|
98
98
|
# TODO: empty continues to exist because of tensor
|
99
99
|
EMPTY = auto()
|
@@ -117,7 +117,7 @@ class Ops(FastEnum):
|
|
117
117
|
REDUCE_AXIS = auto()
|
118
118
|
|
119
119
|
# helper ops
|
120
|
-
GEP = auto(); VECTORIZE = auto() # noqa: E702
|
120
|
+
GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702
|
121
121
|
|
122
122
|
# UnaryOps
|
123
123
|
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
|
@@ -133,7 +133,7 @@ class Ops(FastEnum):
|
|
133
133
|
|
134
134
|
# BinaryOps
|
135
135
|
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
136
|
-
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
|
136
|
+
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
|
137
137
|
|
138
138
|
# TernaryOps
|
139
139
|
WHERE = auto(); MULACC = auto() # noqa: E702
|
@@ -151,18 +151,19 @@ class Ops(FastEnum):
|
|
151
151
|
# device
|
152
152
|
DEVICE = auto()
|
153
153
|
MULTI = auto()
|
154
|
+
CUSTOM = auto()
|
154
155
|
|
155
156
|
class GroupOp:
|
156
157
|
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
|
157
158
|
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
|
158
|
-
Ops.SUB, Ops.FDIV}
|
159
|
+
Ops.SUB, Ops.FDIV, Ops.POW}
|
159
160
|
Ternary = {Ops.WHERE, Ops.MULACC}
|
160
161
|
ALU = set.union(Unary, Binary, Ternary)
|
161
162
|
|
162
163
|
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
163
164
|
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
|
164
165
|
|
165
|
-
Buffer = {Ops.LOAD, Ops.
|
166
|
+
Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
|
166
167
|
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
|
167
168
|
|
168
169
|
# BinaryOps that can be flipped
|
@@ -175,21 +176,21 @@ class GroupOp:
|
|
175
176
|
Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
|
176
177
|
|
177
178
|
# do not preserve f(0) = 0
|
178
|
-
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
|
179
|
+
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
|
179
180
|
|
180
181
|
All = set(Ops)
|
181
182
|
|
182
183
|
# some BUFFER ops can be processed with only a view
|
183
|
-
view_supported_devices = {"LLVM", "
|
184
|
+
view_supported_devices = {"LLVM", "CPU", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
|
184
185
|
|
185
186
|
# https://en.wikipedia.org/wiki/Identity_element
|
186
187
|
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
187
188
|
|
188
|
-
def can_pad(u:UOp, edges:dict[UOp,
|
189
|
+
def can_pad(u:UOp, edges:dict[UOp, None], cache:dict[UOp, None]) -> bool:
|
189
190
|
if u.op in GroupOp.UnsafePad: return False
|
190
|
-
if
|
191
|
-
|
192
|
-
return all(can_pad(x.base, edges,
|
191
|
+
if u in edges or u in cache: return True
|
192
|
+
cache[u] = None
|
193
|
+
return all(can_pad(x.base, edges, cache) for x in u.src)
|
193
194
|
|
194
195
|
# With True as the default, this matches the old symbolic behavior
|
195
196
|
def resolve(x:UOp|bool, default:bool=True):
|
@@ -289,6 +290,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
289
290
|
return ShapeTracker.from_shape(
|
290
291
|
tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
|
291
292
|
if self.op is Ops.BUFFER: return ShapeTracker.from_shape((self.size,))
|
293
|
+
if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
|
292
294
|
# these ops define a ShapeTracker from the arg
|
293
295
|
if self.op is Ops.VIEW: return self.arg
|
294
296
|
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
|
@@ -314,11 +316,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
314
316
|
@property
|
315
317
|
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
|
316
318
|
@property
|
317
|
-
def size(self) -> int: return self.arg
|
319
|
+
def size(self) -> int: return self.arg if self.op is Ops.BUFFER else unwrap(self.st).size
|
318
320
|
|
319
321
|
# *** uop evaluation ***
|
320
322
|
|
321
323
|
def simplify(self):
|
324
|
+
# late import!
|
325
|
+
from tinygrad.codegen.symbolic import symbolic
|
322
326
|
with Context(TRACK_MATCH_STATS=0):
|
323
327
|
return graph_rewrite(self, symbolic)
|
324
328
|
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
|
@@ -342,13 +346,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
342
346
|
assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
|
343
347
|
return unwrap(self.st)
|
344
348
|
@property
|
345
|
-
def const_arg(self) -> ConstType:
|
346
|
-
match self.base.op:
|
347
|
-
case Ops.CONST: ret = self.base.arg
|
348
|
-
case op: raise AssertionError(f"const_arg called on {op}")
|
349
|
-
assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
|
350
|
-
return ret
|
351
|
-
@property
|
352
349
|
def axis_arg(self) -> tuple[int, ...]:
|
353
350
|
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
|
354
351
|
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
|
@@ -366,8 +363,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
366
363
|
assert self.dtype.count == 1
|
367
364
|
if count == 1: return self
|
368
365
|
return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
|
369
|
-
def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
|
370
|
-
def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
|
366
|
+
def cast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.CAST, dtype, (self,))
|
367
|
+
def bitcast(self, dtype:DType): return self if self.dtype == dtype else UOp(Ops.BITCAST, dtype, (self,))
|
371
368
|
def gep(self, i:Union[tuple[int, ...], int]):
|
372
369
|
if isinstance(i, int):
|
373
370
|
# NOTE: these are just shortcuts to not have to create and fold later
|
@@ -489,8 +486,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
489
486
|
if op is Ops.BIND:
|
490
487
|
var, val = arg.unbind()
|
491
488
|
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
|
492
|
-
# otherwise it's just a
|
493
|
-
|
489
|
+
# otherwise it's just a RESHAPE(BUFFER)
|
490
|
+
if not isinstance(size:=prod([x.vmax if isinstance(x, UOp) else x for x in shape]), int): raise ValueError(f"size must be int {size}")
|
491
|
+
return UOp.new_buffer(device, size, dtype).reshape(shape)
|
494
492
|
def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
|
495
493
|
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
|
496
494
|
if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
|
@@ -505,14 +503,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
505
503
|
return ret
|
506
504
|
def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
|
507
505
|
@property
|
508
|
-
def metadata(self): return all_metadata.get(self, None)
|
506
|
+
def metadata(self) -> tuple[Metadata, ...]|Metadata|None: return self.arg.metadata if self.op is Ops.KERNEL else all_metadata.get(self, None)
|
509
507
|
|
510
508
|
# *** uop movement ops ***
|
511
509
|
|
512
510
|
@property
|
513
511
|
def base(self) -> UOp:
|
514
|
-
if self.op in GroupOp.Movement: return self.src[0].base
|
515
|
-
return self
|
512
|
+
if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
|
513
|
+
return self
|
516
514
|
def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
|
517
515
|
|
518
516
|
def _mop(self, op:Ops, arg):
|
@@ -527,11 +525,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
527
525
|
def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
|
528
526
|
def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)
|
529
527
|
|
528
|
+
# *** uop UNIQUE ***
|
529
|
+
|
530
|
+
# TODO: use this in Buffer
|
531
|
+
unique_num = itertools.count(0)
|
532
|
+
@staticmethod
|
533
|
+
def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))
|
534
|
+
|
530
535
|
# *** uop Buffer stuff ***
|
531
536
|
|
532
|
-
buffer_num = itertools.count(0)
|
533
537
|
@staticmethod
|
534
|
-
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),
|
538
|
+
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device), UOp.unique()), size)
|
535
539
|
@property
|
536
540
|
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
|
537
541
|
@functools.cached_property
|
@@ -542,11 +546,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
542
546
|
@property
|
543
547
|
def buf_uop(self) -> UOp:
|
544
548
|
if self.base.op is Ops.BUFFER: return self.base
|
545
|
-
assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN
|
549
|
+
assert self.base.op in {*GroupOp.Buffer, Ops.ASSIGN}, f"buf_uop called on {self.op}"
|
546
550
|
return self.src[0].buf_uop
|
547
551
|
@property
|
548
552
|
def buffer(self) -> Buffer:
|
549
|
-
if self
|
553
|
+
if self is not self.base:
|
550
554
|
assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous"
|
551
555
|
return self.src[0].buffer
|
552
556
|
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
@@ -591,6 +595,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
591
595
|
|
592
596
|
# *** uop symbolic stuff ***
|
593
597
|
|
598
|
+
def is_increasing(self:UOp) -> bool:
|
599
|
+
# is f a monotonically increasing function regards its input
|
600
|
+
if self.op in GroupOp.Irreducible: return True
|
601
|
+
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
|
602
|
+
if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
|
603
|
+
return False # False if not sure
|
594
604
|
def const_factor(self) -> int:
|
595
605
|
"""largest known int that divides self"""
|
596
606
|
if self.op is Ops.CONST: return self.arg
|
@@ -598,7 +608,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
598
608
|
if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
599
609
|
if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
|
600
610
|
return 1
|
601
|
-
def divides(self, v) -> UOp|None:
|
611
|
+
def divides(self, v:int) -> UOp|None:
|
602
612
|
if v==1: return self
|
603
613
|
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
|
604
614
|
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
|
@@ -642,7 +652,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
642
652
|
if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
|
643
653
|
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
|
644
654
|
if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
|
645
|
-
# TODO:
|
655
|
+
# TODO: Ops.SPECIAL is Ops.DEFINE_VAR
|
646
656
|
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
|
647
657
|
if self.op is Ops.CONST: return self.arg, self.arg
|
648
658
|
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
|
@@ -665,20 +675,26 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|
665
675
|
|
666
676
|
@dataclass(frozen=True)
|
667
677
|
class KernelInfo:
|
678
|
+
name: str = "test" # name of the kernel
|
668
679
|
local_dims: int = 0 # number of local dimensions (this is remapping RANGE to SPECIAL)
|
669
680
|
upcasted: int = 0 # count that are upcasted (this is remapping RANGE to UNROLL)
|
670
681
|
dont_use_locals: bool = False # don't use local indexing
|
671
682
|
|
672
|
-
#
|
683
|
+
# ******** ops in python ********
|
673
684
|
|
674
685
|
def safe_exp2(x):
|
675
686
|
try: return 2 ** x
|
676
687
|
except OverflowError: return math.inf
|
677
688
|
|
689
|
+
def safe_pow(x, y):
|
690
|
+
try: return math.nan if isinstance(p:=pow(x, y), complex) else p
|
691
|
+
except ZeroDivisionError: return math.inf
|
692
|
+
except ValueError: return math.inf if x > 0 else -math.inf
|
693
|
+
|
678
694
|
python_alu: dict[Ops, Callable] = {
|
679
695
|
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
|
680
696
|
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
|
681
|
-
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
|
697
|
+
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
|
682
698
|
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
|
683
699
|
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
|
684
700
|
Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
|
@@ -703,7 +719,8 @@ def get_location() -> tuple[str, int]:
|
|
703
719
|
frm = sys._getframe(1)
|
704
720
|
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
|
705
721
|
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
|
706
|
-
"
|
722
|
+
"symbolic.py", "expander.py", "lowerer.py", "cstyle.py",
|
723
|
+
"linearize.py"}:
|
707
724
|
frm = frm.f_back
|
708
725
|
return frm.f_code.co_filename, frm.f_lineno
|
709
726
|
@functools.lru_cache(None)
|
@@ -840,7 +857,9 @@ match_stats:dict[UPat, list[Union[int, float]]] = dict()
|
|
840
857
|
class TrackedGraphRewrite:
|
841
858
|
loc: tuple[str, int] # location that called graph_rewrite
|
842
859
|
sink: UOp # the sink input to graph_rewrite
|
860
|
+
bottom_up: bool
|
843
861
|
matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
|
862
|
+
name: Optional[str] = None
|
844
863
|
tracked_keys:list[Any] = []
|
845
864
|
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
|
846
865
|
_name_cnt:dict[str, int] = {}
|
@@ -923,304 +942,19 @@ class RewriteContext:
|
|
923
942
|
self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
|
924
943
|
return ret
|
925
944
|
|
926
|
-
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
|
927
|
-
if TRACK_MATCH_STATS >= 2 and
|
928
|
-
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
|
945
|
+
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
|
946
|
+
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
947
|
+
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
|
929
948
|
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
|
930
949
|
|
931
|
-
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
|
932
|
-
if TRACK_MATCH_STATS >= 2 and
|
933
|
-
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
|
950
|
+
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
|
951
|
+
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
952
|
+
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
|
934
953
|
rewrite_ctx = RewriteContext(pm, ctx)
|
935
954
|
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
|
936
955
|
|
937
|
-
|
938
|
-
# *** most of symbolic lives here now ***
|
939
|
-
|
940
|
-
def split_uop(x:UOp, sep:Ops):
|
941
|
-
if x.op is sep:
|
942
|
-
for s in x.src: yield from split_uop(s, sep)
|
943
|
-
else: yield x
|
944
|
-
|
945
|
-
def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
|
946
|
-
# simplify x // y or x % y, None means no change
|
947
|
-
# simple cancel div/mod case
|
948
|
-
if y.vmin != 0 != y.vmax and (q:=x.vmin//y.vmin) == x.vmin//y.vmax == x.vmax//y.vmin == x.vmax//y.vmax:
|
949
|
-
return x - q*y if which is Ops.MOD else x.const_like(q)
|
950
|
-
|
951
|
-
if (y.op is not Ops.CONST) or ((c := y.arg) <= 0) or (x.dtype.count > 1): return None
|
952
|
-
|
953
|
-
svars, factors, quotients, remainders, gcd, div, const, offset, something_changed = [], [], [], [], c, 1, 0, 0, False
|
954
|
-
for u in split_uop(x, Ops.ADD):
|
955
|
-
if u.op is Ops.MOD and which is Ops.MOD and u.src[1].op is Ops.CONST and u.src[1].arg%c == 0:
|
956
|
-
u = u.src[0]
|
957
|
-
something_changed = True
|
958
|
-
v: UOp = u.divides(f:=u.const_factor())
|
959
|
-
q, r = divmod(f, c)
|
960
|
-
if r==0 or ((which is Ops.MOD or split_rem or u.op is Ops.CONST) and r!=f): something_changed = True
|
961
|
-
offset += r*v.vmin
|
962
|
-
if u.op is Ops.CONST: const += f
|
963
|
-
else: # div is the smallest common divisor of all terms
|
964
|
-
if f > 1 and c % f == 0 and (div == 1 or div > f): div = f
|
965
|
-
gcd = math.gcd(r, gcd)
|
966
|
-
factors.append(f); svars.append(v); quotients.append(q); remainders.append(r) # noqa: E702
|
967
|
-
|
968
|
-
lbound = ubound = offset = offset % c
|
969
|
-
# we can fold if the expression has only one non-constant term and this term can only take on two values
|
970
|
-
if len(svars)==1 and (v:=svars[0]).vmax-v.vmin == 1:
|
971
|
-
r = (offset+remainders[0])%c - offset%c
|
972
|
-
offset -= r * v.vmin
|
973
|
-
if which is Ops.MOD: return r*v + offset
|
974
|
-
return (factors[0]-r)//c * v + (const-offset)//c
|
975
|
-
|
976
|
-
# a//c = (a-a%c)/c, if we can fold a%c, we can fold a//c
|
977
|
-
# within a mod we can freely subtract multiples of c, we use this to see if a is congruent to an expression whose vmin/vmax are between 0 and c
|
978
|
-
for (r, v) in zip(remainders, svars):
|
979
|
-
if r > c//2:
|
980
|
-
if (lbound := lbound + (r:=r-c) * (v.vmax-v.vmin)) < 0: break
|
981
|
-
elif (ubound := ubound + r * (v.vmax-v.vmin)) >= c: break
|
982
|
-
offset -= r * v.vmin # determine what the new offset would be
|
983
|
-
else: # vmin/vmax of the remainder is between 0 and c, we can remove the mod/div
|
984
|
-
remainders = [min(r, r-c, key=abs) for r in remainders]
|
985
|
-
if which is Ops.MOD: return functools.reduce(operator.add, [r*v for r,v in zip(remainders,svars)], x.const_like(offset))
|
986
|
-
return functools.reduce(operator.add, [(f-r)//c * v for f,r,v in zip(factors, remainders,svars)], x.const_like((const-offset)//c))
|
987
|
-
|
988
|
-
if gcd != 1: something_changed = True
|
989
|
-
if not something_changed:
|
990
|
-
if which is Ops.IDIV and (1 < div < c) and (newx:=div_and_mod_folding(x, UOp.const(dtypes.int, div), Ops.IDIV)) is not None: return newx//(c//div)
|
991
|
-
return None
|
992
|
-
quo, rem = x.const_like(const//c), x.const_like((const%c)//gcd)
|
993
|
-
for q,r,f,v in zip(quotients, remainders, factors, svars):
|
994
|
-
if which is Ops.IDIV and (not split_rem) and r!=0:
|
995
|
-
rem += f//gcd * v
|
996
|
-
else:
|
997
|
-
rem += r//gcd * v
|
998
|
-
quo += q * v
|
999
|
-
|
1000
|
-
if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
|
1001
|
-
return rem//(c//gcd)+quo
|
1002
|
-
|
1003
|
-
def lt_folding(x:UOp, c:int) -> UOp|None:
|
1004
|
-
p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
|
1005
|
-
if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
|
1006
|
-
return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
|
1007
|
-
return None
|
1008
|
-
|
1009
|
-
def fold_unrolled_divs(divs:UOp):
|
1010
|
-
# div pattern in unrolled arange
|
1011
|
-
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
|
1012
|
-
add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
|
1013
|
-
for u in add_chain:
|
1014
|
-
if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
|
1015
|
-
if denominator is None: denominator = u.src[1].arg
|
1016
|
-
if denominator != u.src[1].arg: return None
|
1017
|
-
# assumed CONST is the last of an ADD
|
1018
|
-
if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
|
1019
|
-
seen_const.append(s0.src[1].arg)
|
1020
|
-
s0 = s0.src[0]
|
1021
|
-
else: seen_const.append(0)
|
1022
|
-
if ans is None: ans = s0
|
1023
|
-
if ans is not s0: return None
|
1024
|
-
if denominator is None: return None
|
1025
|
-
# the first (denominator-len(seen_const)) terms may have been folded to 0 already
|
1026
|
-
for i in range(denominator-len(seen_const)):
|
1027
|
-
if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
|
1028
|
-
return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
|
1029
|
-
|
1030
|
-
def canonicalize_simplex(X:UOp) -> UOp|None:
|
1031
|
-
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
|
1032
|
-
# returns x0 + x1 + ... in such case, or None if not
|
1033
|
-
changed, ret = False, []
|
1034
|
-
for u in split_uop(X, Ops.ADD):
|
1035
|
-
# assumed the const is the last src of MUL
|
1036
|
-
if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
|
1037
|
-
changed = True
|
1038
|
-
u = u.src[0]
|
1039
|
-
if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None
|
1040
|
-
ret.append(u)
|
1041
|
-
return functools.reduce(operator.add, ret) if changed else None
|
1042
|
-
|
1043
|
-
def is_increasing(f:UOp) -> bool:
|
1044
|
-
# is f a monotonically increasing function regards its input
|
1045
|
-
if f.op in GroupOp.Irreducible: return True
|
1046
|
-
if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
|
1047
|
-
if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
|
1048
|
-
return False # False if not sure
|
1049
|
-
|
1050
|
-
def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
|
1051
|
-
# if it's X <= c, returns X, True, c
|
1052
|
-
# if it's X >= c, returns X, False, c
|
1053
|
-
|
1054
|
-
# (X < c).ne(True) -> X >= c
|
1055
|
-
if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
|
1056
|
-
(s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
|
1057
|
-
# X < c -> X <= c-1
|
1058
|
-
if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, valid.src[1].arg-1
|
1059
|
-
raise ValueError(f"not able to parse {valid=}")
|
1060
|
-
|
1061
|
-
def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
|
1062
|
-
# return None if valid is always False, otherwise the simplified uop (might be the same as input)
|
1063
|
-
|
1064
|
-
# first, parse valid into {expr: (lower_bound, upper_bound)}
|
1065
|
-
bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None])
|
1066
|
-
for stmt in split_uop(valid, Ops.AND):
|
1067
|
-
try: expr, is_upper, c = parse_valid(stmt)
|
1068
|
-
except ValueError: return uop # give up if we cannot parse the valid
|
1069
|
-
bounds[expr][int(is_upper)] = c
|
1070
|
-
|
1071
|
-
# simplify uop given that valid is True
|
1072
|
-
for expr,v in bounds.items():
|
1073
|
-
# some expr has lower bound > upper bound -> valid is an empty set and we return None
|
1074
|
-
if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
|
1075
|
-
|
1076
|
-
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
|
1077
|
-
candidates = []
|
1078
|
-
if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)):
|
1079
|
-
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
1080
|
-
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
|
1081
|
-
# try checking the whole clause
|
1082
|
-
if expr in uop.toposort:
|
1083
|
-
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
|
1084
|
-
|
1085
|
-
for candidate in candidates:
|
1086
|
-
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
|
1087
|
-
newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
|
1088
|
-
if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
|
1089
|
-
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
|
1090
|
-
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
|
1091
|
-
elif all_same(newuops): uop = newuops[0]
|
1092
|
-
|
1093
|
-
return uop
|
1094
|
-
|
1095
|
-
def _valid_priority(v: UOp, valids:list[UOp]):
|
1096
|
-
# we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
|
1097
|
-
try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
|
1098
|
-
except ValueError: return 0
|
1099
|
-
|
1100
|
-
def simplify_valid(valid:UOp) -> UOp|None:
|
1101
|
-
ret:list[UOp] = []
|
1102
|
-
something_changed = False
|
1103
|
-
valids = list(split_uop(valid, Ops.AND))
|
1104
|
-
for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
|
1105
|
-
ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
|
1106
|
-
if ret[-1] is not stmt: something_changed = True
|
1107
|
-
return functools.reduce(operator.and_, ret) if something_changed else None
|
1108
|
-
|
1109
|
-
# def max_var_const(x:UOp, c1:UOp, c2:UOp):
|
1110
|
-
# if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
|
1111
|
-
# if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
|
1112
|
-
|
1113
956
|
def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
|
1114
957
|
|
1115
|
-
symbolic_simple = PatternMatcher([
|
1116
|
-
# ** self folding **
|
1117
|
-
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
|
1118
|
-
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
|
1119
|
-
(UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
|
1120
|
-
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
|
1121
|
-
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
|
1122
|
-
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
|
1123
|
-
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
1124
|
-
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
|
1125
|
-
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
|
1126
|
-
((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
|
1127
|
-
lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
|
1128
|
-
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
|
1129
|
-
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
|
1130
|
-
(UPat(GroupOp.Idempotent, src=(UPat.var("x"), UPat.var("x"))), lambda x: x),
|
1131
|
-
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
|
1132
|
-
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
|
1133
|
-
# ** zero folding **
|
1134
|
-
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
|
1135
|
-
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints),
|
1136
|
-
lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x != x -> False (only ints)
|
1137
|
-
# x*0 -> 0 or 0*x -> 0
|
1138
|
-
# if x is nan or inf it should render the nan value.
|
1139
|
-
# NOTE: this can be wrong for loaded NaN
|
1140
|
-
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
1141
|
-
# ** constant folding **
|
1142
|
-
# TODO: add const folding for Ops.THREEFRY
|
1143
|
-
(UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))),
|
1144
|
-
lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False)) if a.op is not Ops.THREEFRY else None),
|
1145
|
-
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
1146
|
-
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
|
1147
|
-
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
|
1148
|
-
(UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
|
1149
|
-
# *** cast ***
|
1150
|
-
(UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
1151
|
-
(UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
1152
|
-
])
|
1153
|
-
|
1154
|
-
symbolic = symbolic_simple+PatternMatcher([
|
1155
|
-
# ** COMMUTATIVE flipping **
|
1156
|
-
(UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
1157
|
-
# ** boolean algebra **
|
1158
|
-
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
1159
|
-
# ** combine terms **
|
1160
|
-
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
1161
|
-
((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)),
|
1162
|
-
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
1163
|
-
((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
|
1164
|
-
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
|
1165
|
-
((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
|
1166
|
-
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
|
1167
|
-
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
|
1168
|
-
# a conditional with the same results either way is a noop, also fold const conditionals
|
1169
|
-
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
1170
|
-
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
1171
|
-
# alu of two where with same conds can combine, only do if true branch or false branch is const
|
1172
|
-
(UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
|
1173
|
-
lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
|
1174
|
-
# ALU min==max -> CONST (slow!)
|
1175
|
-
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
1176
|
-
# max folding
|
1177
|
-
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
1178
|
-
# TODO: why does this rule break beautiful_mnist?
|
1179
|
-
#((UPat.var("x")+UPat.var("z")).maximum(UPat.var("y")+UPat.var("z")), lambda x,y,z: x.maximum(y) + z),
|
1180
|
-
#((UPat.var("x")*UPat.cvar("c1")).maximum(UPat.var("x")*UPat.cvar("c2")), max_var_const),
|
1181
|
-
# ** two stage ALU folding **
|
1182
|
-
*((UPat.var("x").alu(op, UPat.cvar("c1")).alu(op, UPat.cvar("c2")).named("f"),
|
1183
|
-
lambda f,x,c1,c2: x.alu(f.op,c1.alu(f.op,c2))) for op in GroupOp.Associative),
|
1184
|
-
((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)), # c0 + x < c1 -> x < c1 - c0
|
1185
|
-
((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
|
1186
|
-
# ** lt **
|
1187
|
-
# c0*x<c1 for positive int c0,c1
|
1188
|
-
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
|
1189
|
-
lambda x,c0,c1: x<math.ceil(c1.arg/c0.arg) if c0.arg > 0 and c1.arg > 0 else None),
|
1190
|
-
# c0*x<c1 for negative int c0 and non-positive c1
|
1191
|
-
((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints))<UPat.cvar("c1", vec=False),
|
1192
|
-
lambda x,c0,c1: (-x)<(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
|
1193
|
-
# x//c0<c1 for positive int c0
|
1194
|
-
((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False))<UPat.cvar("c1", vec=False),
|
1195
|
-
lambda x,c0,c1: x<(c1.arg*c0.arg) if c0.arg > 0 else None),
|
1196
|
-
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
|
1197
|
-
(UPat(Ops.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
|
1198
|
-
(UPat(Ops.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
|
1199
|
-
# *** rules from symbolic ***
|
1200
|
-
# unrolled arange div folding
|
1201
|
-
(UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
|
1202
|
-
# generic lt folding
|
1203
|
-
(UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
|
1204
|
-
# canonicalize a simplex with positive coefficients > 0
|
1205
|
-
# not x < 1 -> X > 0
|
1206
|
-
((UPat.var("x", dtypes.ints)<1).ne(True), lambda x: (newx<1).ne(True) if (newx:=canonicalize_simplex(x)) is not None else None),
|
1207
|
-
# ** div **
|
1208
|
-
# div folding
|
1209
|
-
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d)
|
1210
|
-
(UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
|
1211
|
-
# ** mod **
|
1212
|
-
# mod folding
|
1213
|
-
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
|
1214
|
-
])
|
1215
|
-
|
1216
|
-
|
1217
|
-
symbolic_flat = symbolic+PatternMatcher([
|
1218
|
-
# ** combine terms (opinionated) **
|
1219
|
-
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
|
1220
|
-
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
|
1221
|
-
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
|
1222
|
-
])
|
1223
|
-
|
1224
958
|
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
|
1225
959
|
|
1226
960
|
# for debug
|
@@ -1250,7 +984,8 @@ ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
|
|
1250
984
|
merge_views = PatternMatcher([
|
1251
985
|
# VIEW(VIEW) merges to a single VIEW
|
1252
986
|
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
|
1253
|
-
|
987
|
+
# remove VIEW if it's contiguous and same as the base shape
|
988
|
+
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda vm,x: x if vm.st.contiguous and x.shape == vm.shape else None),
|
1254
989
|
# merge unmasked const views
|
1255
990
|
(UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
|
1256
991
|
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
|
tinygrad/renderer/__init__.py
CHANGED
@@ -1,11 +1,24 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
from typing import Optional, Callable
|
3
3
|
import functools, math
|
4
|
+
from enum import Enum, auto
|
4
5
|
from dataclasses import dataclass, field, replace
|
5
6
|
from tinygrad.helpers import to_function_name, dedup, prod
|
6
7
|
from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
|
7
8
|
from tinygrad.dtype import DType
|
8
9
|
|
10
|
+
class OptOps(Enum):
|
11
|
+
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
12
|
+
GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
|
13
|
+
def __lt__(self, x:OptOps): return self.value < x.value
|
14
|
+
|
15
|
+
@dataclass(frozen=True, order=True)
|
16
|
+
class Opt:
|
17
|
+
op: OptOps
|
18
|
+
axis: Optional[int] = None
|
19
|
+
arg: Optional[int | tuple] = None
|
20
|
+
def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
|
21
|
+
|
9
22
|
@dataclass(frozen=True)
|
10
23
|
class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
|
11
24
|
dims: tuple[int,int,int] # N, M, K
|
@@ -70,7 +83,9 @@ class ProgramSpec:
|
|
70
83
|
name:str
|
71
84
|
src:str
|
72
85
|
device:str
|
86
|
+
ast:UOp # save the base ast (this is method cache key)
|
73
87
|
uops:Optional[list[UOp]]=None
|
88
|
+
applied_opts:Optional[list[Opt]]=None
|
74
89
|
mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
|
75
90
|
|
76
91
|
# filled in from uops (if we have uops)
|
@@ -121,12 +136,13 @@ class Renderer:
|
|
121
136
|
has_local: bool = True
|
122
137
|
has_shared: bool = True
|
123
138
|
# NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
|
124
|
-
global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO:
|
125
|
-
local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO:
|
139
|
+
global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
|
140
|
+
local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
|
126
141
|
shared_max: int = 32768
|
127
142
|
tensor_cores: list[TensorCore] = []
|
143
|
+
pre_matcher: Optional[PatternMatcher] = None
|
128
144
|
extra_matcher: Optional[PatternMatcher] = None
|
129
145
|
code_for_op: dict[Ops, Callable] = {}
|
130
146
|
|
131
147
|
def __reduce__(self): return self.__class__, ()
|
132
|
-
def render(self,
|
148
|
+
def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
|