tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/uop/__init__.py
ADDED
@@ -0,0 +1,117 @@
|
|
1
|
+
from enum import auto, IntEnum, Enum
|
2
|
+
|
3
|
+
# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
|
4
|
+
class FastEnum(IntEnum):
  """IntEnum that prints like a plain Enum and hands out `auto()` values unique across its subclasses.

  `str(member)` uses `Enum.__str__` instead of the IntEnum integer form.
  `auto()` yields one more than the largest value seen so far, looking at both the members already
  defined in the class being built and the members of every existing direct subclass of `FastEnum`.
  """
  def __str__(self): return Enum.__str__(self)
  @staticmethod
  def _generate_next_value_(_, __, ___, last_values):
    # max(c, default=0): a subclass declared without members iterates as empty, and a bare max() over it
    # would raise ValueError and break auto() for every FastEnum subclass defined afterwards
    taken = [max(c, default=0) for c in FastEnum.__subclasses__()]
    return 1 + max([0, *last_values, *taken])
|
8
|
+
|
9
|
+
# the order of these Ops controls the order of the toposort
|
10
|
+
class Ops(FastEnum):
  # Every member gets its value from auto(), i.e. in declaration order (left to right, top to bottom).
  # Moving, inserting or removing a member therefore renumbers everything declared after it.

  # uops that aren't rendered
  NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702

  # track children
  CHILD = auto(); CHILDREN = auto() # noqa: E702

  # buffer ops
  COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702

  # create buffer
  BUFFERIZE = auto()

  # ops that adjust the behavior of the scheduler
  CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702

  # blocks in linearizer (only used there)
  BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702

  # movement ops! these only exist in the tensor graph
  RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
  MULTI = auto() # MULTI is really a movement op

  # view is what all movement ops become
  VIEW = auto()

  # TODO: remove VALID with the VIEW(CONST(DEVICE)) refactor
  VALID = auto()

  # TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
  DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702

  # this is for symbolic shapes
  DEFINE_VAR = auto(); BIND = auto() # noqa: E702

  # this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
  SPECIAL = auto()

  # reduce
  REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702

  # optimization helper ops
  UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702

  # UnaryOps
  CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto(); TRUNC = auto() # noqa: E702

  # load/store before math
  LOAD = auto(); STORE = auto() # noqa: E702
  ASSIGN = auto() # TODO: ASSIGN is STORE, remove ASSIGN

  # tensor core math op, not elementwise
  WMMA = auto()

  # INDEX is a BinaryOp similar to ADD, but it operates on pointers
  INDEX = auto()

  # BinaryOps
  ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
  CMPLT = auto(); CMPNE = auto(); CMPEQ = auto() # noqa: E702
  XOR = auto(); OR = auto(); AND = auto() # noqa: E702
  THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702

  # TernaryOps
  WHERE = auto(); MULACC = auto() # noqa: E702

  # control flow ops
  BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702

  # consts. VCONST is a vectorized const
  VCONST = auto(); CONST = auto() # noqa: E702

  # CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
  CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
|
84
|
+
|
85
|
+
class GroupOp:
  """Named sets of `Ops` members, used as a namespace of class attributes (no instances are created here)."""
  # math ops taking one source; CAST and BITCAST are not in this set, they are only added in Elementwise
  Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG, Ops.TRUNC}
  # math ops taking two sources
  Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
            Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
  # math ops taking three sources
  Ternary = {Ops.WHERE, Ops.MULACC}
  # union of the three arity sets above
  ALU = set.union(Unary, Binary, Ternary)

  # TODO: is BITCAST always Elementwise if it's shape changing?
  Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})

  # the three DEFINE_* ops; DEFINE_VAR is not included
  Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}

  Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
  # the six movement ops; MULTI is not included
  Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}

  Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
  # linearizer block ops; BLOCKFINAL is not included
  Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART}

  # BinaryOps that can be flipped
  Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.CMPEQ, Ops.XOR, Ops.AND, Ops.OR}

  # BinaryOps where f(f(a,b),c) = f(a,f(b,c))
  Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}

  # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
  Idempotent = {Ops.OR, Ops.AND, Ops.MAX}

  # do not preserve f(0) = 0
  UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}

  Meta = {Ops.COPY, Ops.BUFFER_VIEW}

  # every member of Ops
  All = set(Ops)
|
@@ -1,7 +1,9 @@
|
|
1
|
-
import
|
2
|
-
|
3
|
-
from tinygrad.
|
4
|
-
from tinygrad.
|
1
|
+
from typing import Callable
|
2
|
+
import math, functools
|
3
|
+
from tinygrad.dtype import dtypes, DType, promo_lattice
|
4
|
+
from tinygrad.device import is_dtype_supported
|
5
|
+
from tinygrad.helpers import polyN, DISABLE_FAST_IDIV
|
6
|
+
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher
|
5
7
|
|
6
8
|
TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
|
7
9
|
|
@@ -10,9 +12,9 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
|
|
10
12
|
return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
|
11
13
|
|
12
14
|
# *** helper functions for bit manipulation ***
|
13
|
-
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1]
|
14
|
-
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d]
|
15
|
-
def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d]
|
15
|
+
def mantissa_bits(d:DType) -> int:
  """Return element [1] of `dtypes.finfo` for the scalar form of `d` (presumably the mantissa width in bits)."""
  scalar_info = dtypes.finfo(d.scalar())
  return scalar_info[1]
|
16
|
+
def exponent_bias(d:DType) -> int:
  """Return the exponent bias (1023 / 127 / 15) for float64 / float32 / float16, keyed by the scalar form of `d`.

  Any other dtype raises KeyError.
  """
  bias_by_dtype = {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}
  return bias_by_dtype[d.scalar()]
|
17
|
+
def exponent_mask(d:DType) -> int:
  """Return the exponent-field mask (2047 / 255 / 31, i.e. 11 / 8 / 5 one-bits) for float64 / float32 / float16.

  The lookup uses the scalar form of `d`; any other dtype raises KeyError.
  """
  mask_by_dtype = {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}
  return mask_by_dtype[d.scalar()]
|
16
18
|
|
17
19
|
# **** utils ****
|
18
20
|
def shr(x:UOp, y:int) -> UOp: return x // (2**y)
|
@@ -20,41 +22,41 @@ def shl(x:UOp, y:int) -> UOp: return x * (2**y)
|
|
20
22
|
|
21
23
|
def rintk(d:UOp) -> UOp:
|
22
24
|
"""round d:float to int away from 0"""
|
23
|
-
out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
|
25
|
+
out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount)
|
24
26
|
return (d + (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)
|
25
27
|
|
26
28
|
def pow2if(q:UOp, float_dtype:DType):
|
27
29
|
"""cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
|
28
|
-
out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype]
|
30
|
+
out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype.scalar()}[q.dtype.scalar()].vec(q.dtype.vcount)
|
29
31
|
return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype)
|
30
32
|
|
31
33
|
def ilogb2k(d:UOp) -> UOp:
|
32
34
|
"""calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
|
33
|
-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
34
|
-
dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype])
|
35
|
+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
36
|
+
dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount))
|
35
37
|
# -1 <= ilogb2k(d) <= 128
|
36
38
|
return (shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype)
|
37
39
|
|
38
40
|
def ldexp3k(d:UOp, e:UOp) -> UOp:
|
39
41
|
"""d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
|
40
|
-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
41
|
-
|
42
|
-
m1 = d.bitcast(
|
43
|
-
m2 = shl(e.cast(
|
42
|
+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
43
|
+
dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.count)
|
44
|
+
m1 = d.bitcast(dtype)
|
45
|
+
m2 = shl(e.cast(dtype), mantissa_bits(d.dtype))
|
44
46
|
return (m1 + m2).bitcast(d.dtype).cast(d.dtype)
|
45
47
|
|
46
48
|
def ldexp2k(d:UOp, e:UOp) -> UOp:
|
47
49
|
"""d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
|
48
|
-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
|
50
|
+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in (dtypes.int16, dtypes.int32, dtypes.int64)
|
49
51
|
return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
|
50
52
|
|
51
53
|
def frexp(v:UOp) -> tuple[UOp, UOp]:
|
52
54
|
"""frexp(v) -> (mantissa, exponent) assuming v != 0"""
|
53
|
-
assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
55
|
+
assert v.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
54
56
|
# m1 = masks for mantissa, m2 = masks to normalize the mantissa.
|
55
|
-
m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
|
56
|
-
m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype]
|
57
|
-
bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype])
|
57
|
+
m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype.scalar()]
|
58
|
+
m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype.scalar()]
|
59
|
+
bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype.scalar()].vec(v.dtype.count))
|
58
60
|
exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype)
|
59
61
|
# Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
|
60
62
|
mantissa = ((bits & m1) | m2).bitcast(v.dtype)
|
@@ -70,12 +72,12 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
|
|
70
72
|
- `r`[d.dtype] is the remainder value corresponding to `round_to_nearest(x % pi/2)`.
|
71
73
|
- `q`[int32] is an integer, and q % 4 is corresponding to the quadrant of the original angle `d`.
|
72
74
|
"""
|
73
|
-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
75
|
+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
74
76
|
# https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
|
75
77
|
# 190 bits of 2/pi for Payne-Hanek style argument reduction
|
76
78
|
two_over_pi_f = [0x00000000, 0x28be60db, 0x9391054a, 0x7f09d5f4, 0x7d4d3770, 0x36d8a566, 0x4f10e410]
|
77
79
|
|
78
|
-
intermediate_dtype = dtypes.float32 if d.dtype == dtypes.float16 else d.dtype
|
80
|
+
intermediate_dtype = dtypes.float32.vec(d.dtype.count) if d.dtype.base.scalar() == dtypes.float16 else d.dtype
|
79
81
|
|
80
82
|
f, e = frexp(d)
|
81
83
|
ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(dtypes.uint64)
|
@@ -89,10 +91,10 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
|
|
89
91
|
if count+offset < len(two_over_pi_f) - 1:
|
90
92
|
an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
|
91
93
|
return an
|
92
|
-
def _shl_lazy(x, y): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
|
93
|
-
def _shr_lazy(x, y): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
|
94
|
+
def _shl_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
|
95
|
+
def _shr_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
|
94
96
|
|
95
|
-
a = [_take(UOp.const(dtypes.uint32, 0), i) for i in range(4)]
|
97
|
+
a = [_take(UOp.const(dtypes.uint32.vec(d.dtype.count), 0), i) for i in range(4)]
|
96
98
|
# (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
|
97
99
|
# Note: e >= 1 for all numbers d >= 1.0. assume e != 0
|
98
100
|
hi = _shl_lazy(a[0], e) | _shr_lazy(a[1], offset)
|
@@ -119,7 +121,7 @@ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
|
|
119
121
|
"""
|
120
122
|
def _reduce_d(x:UOp, q:UOp):
|
121
123
|
# https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefdp.c#L789-L823
|
122
|
-
if x.dtype == dtypes.float64:
|
124
|
+
if x.dtype.scalar() == dtypes.float64:
|
123
125
|
# https://github.com/shibatch/sleef/blob/f6d8a841fbfddd26ce712834d4da220cd76048fb/src/common/misc.h#L77
|
124
126
|
PI_A, PI_B, PI_C, PI_D = 3.1415926218032836914, 3.1786509424591713469e-08, 1.2246467864107188502e-16, 1.2736634327021899816e-24
|
125
127
|
d = qdh * -PI_A + x
|
@@ -129,7 +131,7 @@ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
|
|
129
131
|
d = qdh * -PI_C + d
|
130
132
|
d = q * -PI_C + d
|
131
133
|
d = (qdh + q) * -PI_D + d
|
132
|
-
elif x.dtype == dtypes.float16:
|
134
|
+
elif x.dtype.scalar() == dtypes.float16:
|
133
135
|
# [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
|
134
136
|
d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
|
135
137
|
else:
|
@@ -142,11 +144,11 @@ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
|
|
142
144
|
|
143
145
|
m_1_pi = 0.318309886183790671537767526745028724
|
144
146
|
qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64).cast(d.dtype) * (2.0**24)
|
145
|
-
quadrant = rintk(d * m_1_pi -qdh) if d.dtype == dtypes.float64 else rintk(d * m_1_pi)
|
147
|
+
quadrant = rintk(d * m_1_pi -qdh) if d.dtype.base.scalar() == dtypes.float64 else rintk(d * m_1_pi)
|
146
148
|
return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32)
|
147
149
|
|
148
150
|
# *** approximate sine on small angle. ***
|
149
|
-
def trig_poly(d:UOp, coeff32, coeff64): return d * (polyN(d*d, coeff64) if d.dtype == dtypes.float64 else polyN(d*d, coeff32))
|
151
|
+
def trig_poly(d:UOp, coeff32, coeff64):
  """Return `d * polyN(d*d, coeffs)`, using `coeff64` when the scalar dtype of `d` is float64 and `coeff32` otherwise."""
  coeffs = coeff64 if d.dtype.scalar() == dtypes.float64 else coeff32
  return d * polyN(d*d, coeffs)
|
150
152
|
# approximate sine on [-pi/2, pi/2]
|
151
153
|
def sin_poly(d:UOp) -> UOp:
|
152
154
|
return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938, 1.0],
|
@@ -172,7 +174,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
|
|
172
174
|
- fast=True assumes x <= switch_over.
|
173
175
|
- switch_over is the threshold for switching to payne_hanek_reduction.
|
174
176
|
"""
|
175
|
-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
177
|
+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
176
178
|
# mask +-inf/nan as zero
|
177
179
|
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
|
178
180
|
# x_sign = sign(x)
|
@@ -194,20 +196,20 @@ def xexp2(d:UOp) -> UOp:
|
|
194
196
|
Implements a 1.0 ULP approximation for Ops.EXP2
|
195
197
|
- Paper: https://arxiv.org/pdf/2001.09258
|
196
198
|
"""
|
197
|
-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
199
|
+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
198
200
|
# mask +-inf/nan as zero.
|
199
201
|
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
|
200
202
|
q = rintk(x)
|
201
203
|
# s = d - round(d)
|
202
204
|
s = x - q.cast(x.dtype)
|
203
205
|
# a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2].
|
204
|
-
if d.dtype == dtypes.float64:
|
206
|
+
if d.dtype.scalar() == dtypes.float64:
|
205
207
|
u = polyN(s, [0.4434359082926529454e-9, 0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4,
|
206
208
|
0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
|
207
209
|
0.6931471805599452862e+0, 0.1000000000000000000e+1])
|
208
210
|
else: u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0])
|
209
211
|
u = ldexp2k(u, q) # u*2^q
|
210
|
-
upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype]
|
212
|
+
upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype.scalar()]
|
211
213
|
# Replace x >= upper with +inf
|
212
214
|
u = (d >= upper).where(d.const_like(math.inf), u)
|
213
215
|
# Replace x < lower with zero.
|
@@ -220,10 +222,10 @@ def xlog2(d:UOp) -> UOp:
|
|
220
222
|
Implements a 1.0 ULP approximation for Ops.LOG2
|
221
223
|
Paper: https://arxiv.org/pdf/2001.09258 5.5
|
222
224
|
"""
|
223
|
-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
225
|
+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
224
226
|
# TODO: float16 denormal need float32 to achieve precision
|
225
|
-
if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
|
226
|
-
FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
|
227
|
+
if d.dtype.scalar() == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
|
228
|
+
FLT_MIN = d.const_like(1e-6 if d.dtype.scalar() == dtypes.float16 else 1e-4)
|
227
229
|
is_denormal = d<FLT_MIN
|
228
230
|
a = is_denormal.where(d * (2 ** 64), d)
|
229
231
|
|
@@ -233,7 +235,7 @@ def xlog2(d:UOp) -> UOp:
|
|
233
235
|
|
234
236
|
x = (m - 1.0) / (m + 1.0)
|
235
237
|
x2 = x * x
|
236
|
-
if d.dtype == dtypes.float64:
|
238
|
+
if d.dtype.scalar() == dtypes.float64:
|
237
239
|
t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
|
238
240
|
0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
|
239
241
|
s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
|
@@ -248,7 +250,7 @@ def xlog2(d:UOp) -> UOp:
|
|
248
250
|
r = (d<-0.0).where(r.const_like(math.nan), r)
|
249
251
|
# log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
|
250
252
|
# log2_zero = the value of unmasked xlog2(0.0).
|
251
|
-
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
|
253
|
+
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype.scalar()]
|
252
254
|
r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
|
253
255
|
# log2(NaN) = NaN
|
254
256
|
r = d.ne(d).where(r.const_like(math.nan), r)
|
@@ -264,3 +266,88 @@ def xpow(base:UOp, exponent:UOp) -> UOp:
|
|
264
266
|
(exponent < 0).where(-exponent, exponent).cast(dtypes.int32).mod(2).cast(dtypes.bool).where(ret.const_like(-1), ret.const_like(1)))
|
265
267
|
# fix 0 ** 0 = 1
|
266
268
|
return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * (base < 0).where(adj, ret.const_like(1)))
|
269
|
+
|
270
|
+
# *** integer division ***
|
271
|
+
|
272
|
+
@functools.lru_cache(None)
def magicgu(vmax:int, d:int) -> tuple[int,int]:
  """Find `(m, s)` such that `x//d == (x*m) >> s` for all `0 <= x <= vmax`, with `d > 0`.

  Adapted from Hacker's Delight, Chapter 10. Results are memoized per `(vmax, d)`.

  Raises:
    AssertionError: if no shift `s` in `[0, 2*vmax.bit_length()]` satisfies the search condition.
  """
  # every x in [0, vmax] is smaller than d, so the quotient is always 0. Without this guard nc below is
  # negative and the search accepts s=0 with m=1, which computes x instead of 0 (e.g. vmax=3, d=7)
  if d > vmax + 1: return 0, 0
  nc = (vmax+1)//d * d - 1
  nbits = vmax.bit_length()
  for s in range(0, 2*nbits + 1):
    pow2 = 1 << s
    slack = d - 1 - (pow2 - 1) % d
    if pow2 > nc*slack: return (pow2 + slack)//d, s
  # explicit raise: `assert False` is removed under `python -O`, which would make this function return None
  raise AssertionError(f"no magic number found for {vmax=} {d=}")
|
282
|
+
|
283
|
+
def fast_idiv(device: str, x: UOp, d: int) -> UOp|None:
  """Rewrite `x // d` for a positive constant `d` as a multiply and right shift, or return None if it cannot be done.

  The multiplier and shift come from `magicgu`, sized by the value range of `x` clamped to its dtype.
  The product is computed in `x.dtype` when `m*vmin` and `m*vmax` both fit there. Otherwise it is computed
  in the last entry of `promo_lattice[x.dtype]`, provided that entry is an int dtype, is supported on
  `device` (an empty string is passed on as None), and can hold the product; the result is cast back.
  When `x` may be negative, 1 is added for negative `x`.
  """
  # If d is a power of two this is not valid for signed ints!
  is_unsigned = x.vmin>=0 or x.dtype in dtypes.uints
  assert d>0, "Sign should have been taken out of divisor"
  lo, hi = max(x.vmin, x.dtype.min), min(x.vmax, x.dtype.max)
  magic, shift = magicgu(max(hi, abs(lo)), d)

  def _product_fits(dt) -> bool:
    # both extremes of x*magic must be representable in dt
    return magic*lo >= dtypes.min(dt) and magic*hi <= dtypes.max(dt)

  def _sign_fix(quotient: UOp) -> UOp:
    # the shifted product is used as-is for non-negative x; negative x gets +1
    if is_unsigned: return quotient
    return quotient + (x<0).where(x.ufix(1), 0)

  if _product_fits(x.dtype): return _sign_fix((x*magic) >> shift)
  # promo_lattice needs to return an unsigned type if the type is unsigned
  wider = promo_lattice[x.dtype][-1]
  if not (dtypes.is_int(wider) and is_dtype_supported(wider, None if device=='' else device)): return None
  if not _product_fits(wider): return None
  return _sign_fix(((x.cast(wider)*magic) >> shift).cast(x.dtype))
|
296
|
+
|
297
|
+
# ***** threefry *****
|
298
|
+
|
299
|
+
def threefry2x32(x: UOp, key: UOp):
  """Build the UOp expression that mixes `x` with `key` Threefry-2x32 style and return it as one uint64 value.

  Both inputs are split into low/high uint32 halves. Five groups of four add/rotate/xor rounds are applied,
  alternating between two rotation tables, with a subkey injection after each group. Rotations are spelled
  as multiply plus floor-divide by powers of two rather than shifts.
  """
  low_mask = 0xffffffff
  # split x and key from uint64 to two uint32
  x_lo = (x & low_mask).cast(dtypes.uint32)
  x_hi = ((x // 2**32) & low_mask).cast(dtypes.uint32)
  key_lo = (key & low_mask).cast(dtypes.uint32)
  key_hi = ((key // 2**32) & low_mask).cast(dtypes.uint32)

  rotation_tables = ((13, 15, 26, 6), (17, 29, 16, 24))
  subkeys = (key_hi, key_lo ^ key_hi ^ 0x1BD11BDA, key_lo)
  a, b = x_lo + subkeys[-1], x_hi + subkeys[0]
  for group in range(5):
    for rot in rotation_tables[group % 2]:
      a = a + b
      # rotate the previous b left by rot bits, then xor with the fresh sum
      b = a ^ ((b * 2**rot) + (b // 2**(32 - rot)))
    a = a + subkeys[group % 3]
    b = b + subkeys[(group + 1) % 3] + group + 1

  # reassemble: b is the high word, a the low word
  return b.cast(dtypes.uint64) * 2**32 | a.cast(dtypes.uint64)
|
312
|
+
|
313
|
+
# ***** decomposition patterns *****
|
314
|
+
|
315
|
+
# lookup table mapping 2**exp -> exp for every exp in [0, 64)
powers_of_two = {1 << exp: exp for exp in range(64)}
|
316
|
+
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental=False):
  """Build the PatternMatcher of late decomposition rewrites for a backend that supports `ops`.

  The result is memoized on `(ops, force_transcendental)`, so `ops` must be hashable.
  Patterns are appended in the order written below.

  Args:
    ops: the Ops the target can execute directly.
    force_transcendental: when true, EXP2/LOG2/SIN are rewritten to xexp2/xlog2/xsin even if they are in `ops`.
  """
  # EXP2/LOG2/SIN on float16/32/64 are replaced by the software versions when unsupported (or when forced)
  pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
    ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
  # no real hardware supports THREEFRY, but NullRenderer does
  if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
  # MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
  if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
  # rewrite SQRT to xpow 0.5
  if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
  # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
  if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
  # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
  # powers_of_two.get(c.arg, 0) is 0 (falsy) both for non-powers and for c == 1, so x*1 and x//1 are left alone
  if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
  if Ops.SHR in ops:
    # no reason to check x<0 for uints
    pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
    # when (x<0) has a single possible value (vmin==vmax) it is replaced by that constant before the where
    pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(
      c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v
    if not DISABLE_FAST_IDIV:
      # the matcher context `ctx` is passed to fast_idiv as its `device` argument
      pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d"), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
      # x % d -> x - d*(x//d)
      pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
  if Ops.NEG in ops:
    # x * -1 -> NEG(x)
    pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
    # x + NEG(y) -> SUB(x, y)
    if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
  if Ops.CMPLT in ops:
    # These are late rewrites because simplex expects equalities to be a certain format
    pat += [
      ((UPat.var("x", dtypes.sints) < UPat.cvar("c", dtypes.sints)).logical_not(), lambda x,c: c-1<x),
      ((UPat.cvar("c", dtypes.sints) < UPat.var("x", dtypes.sints)).logical_not(), lambda x,c: x<c+1),
      (UPat.var("x", dtypes.sints)*-1 < UPat.var("y", dtypes.sints)*UPat.cvar("c"), lambda x,y,c: y*(-c)<x),
      (UPat.var("x", dtypes.sints)*-1 < UPat.cvar("c"), lambda x,c:-c<x),
      ((UPat.cvar("c1",vec=False)<UPat.var("x", dtypes.sints)) & (UPat.var("x", dtypes.sints)<UPat.cvar("c2",vec=False)),
        lambda x,c1,c2: x.eq(c1+1) if c1.arg+1==c2.arg-1 else None), # (c-1)<x & x<(c+1) -> x==c
    ]
  # not(x != y) -> CMPEQ(x, y)
  if Ops.CMPEQ in ops: pat += [(UPat.var('x').ne(UPat.var('y')).logical_not(), lambda x,y: x.alu(Ops.CMPEQ, y))]
  # a*b + c -> MULACC(a, b, c)
  if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
  return PatternMatcher(pat)
|
@@ -0,0 +1,169 @@
|
|
1
|
+
from tinygrad.uop import Ops
|
2
|
+
from tinygrad.helpers import T
|
3
|
+
from tinygrad.dtype import dtypes
|
4
|
+
|
5
|
+
class MathTrait:
  """
  Mixin that derives arithmetic, bitwise, comparison and unary-math helpers from two primitives.

  Subclasses must provide `alu` and `const_like`; every other method in this class is written in terms of them.
  """

  # ***** primitives every subclass has to provide *****

  def alu(self:T, op:Ops, *src) -> T: raise NotImplementedError
  def const_like(self:T, b) -> T: raise NotImplementedError

  # ***** shared plumbing *****

  def ufix(self, x):
    """Return `x` unchanged when it is already a `MathTrait`, otherwise wrap it with `self.const_like`."""
    if isinstance(x, MathTrait): return x
    return self.const_like(x)

  def _binop(self, op, x, reverse):
    """Apply binary `op` to `self` and `x`; with `reverse` set, `x` becomes the left operand."""
    other = self.ufix(x)
    if reverse: return other.alu(op, self)
    return self.alu(op, other)

  def logical_not(self):
    """Return `self.ne(True)`."""
    return self.ne(True)

  def neg(self):
    """
    Negate `self`: logical NOT when the scalar dtype is bool, multiplication by -1 otherwise.

    Raises `TypeError` when `self.dtype` is None.
    """
    dtype = getattr(self, 'dtype')
    if dtype is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
    if dtype.scalar() == dtypes.bool: return self.logical_not()
    return self*(-1)

  def _check_dtype(self):
    """
    Raise `RuntimeError` unless `self.dtype` is None, a bool dtype or an int dtype.

    When the dtype is a tuple only its first entry is inspected.
    """
    dtype = getattr(self, 'dtype')
    if dtype is None: return
    if isinstance(dtype, tuple): dtype = dtype[0]
    if dtypes.is_bool(dtype) or dtypes.is_int(dtype): return
    raise RuntimeError(f"{dtype} is not supported")

  # ***** named binary operations *****

  def add(self, x, reverse=False):
    """
    Adds `self` and `x`.
    Equivalent to `self + x`.
    Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    t = Tensor.randn(4)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.add(20).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.add(Tensor([[2.0], [3.5]])).numpy())
    ```
    """
    return self._binop(Ops.ADD, x, reverse)

  def mul(self, x, reverse=False):
    """
    Multiplies `self` and `x`.
    Equivalent to `self * x`.
    Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    t = Tensor.randn(4)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.mul(3).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
    ```
    """
    return self._binop(Ops.MUL, x, reverse)

  def bitwise_and(self, x, reverse=False):
    """
    Computes the bitwise AND of `self` and `x`.
    Equivalent to `self & x`.
    Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
    ```
    """
    self._check_dtype()
    return self._binop(Ops.AND, x, reverse)

  def bitwise_or(self, x, reverse=False):
    """
    Computes the bitwise OR of `self` and `x`.
    Equivalent to `self | x`.
    Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
    ```
    """
    self._check_dtype()
    return self._binop(Ops.OR, x, reverse)

  def bitwise_xor(self, x, reverse=False):
    """
    Computes bitwise xor of `self` and `x`.
    Equivalent to `self ^ x`.
    Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor([-1, -2, 3]).bitwise_xor(Tensor([1, 0, 3])).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor([True, True, False, False]).bitwise_xor(Tensor([True, False, True, False])).numpy())
    ```
    """
    self._check_dtype()
    return self._binop(Ops.XOR, x, reverse)

  def idiv(self, x, reverse=False):
    """
    Divides `self` by `x`.
    Equivalent to `self // x`.
    Supports broadcasting to a common shape, type promotion, and integer inputs.
    `idiv` performs integer division (truncate towards zero).

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
    ```
    """
    return self._binop(Ops.IDIV, x, reverse)

  def mod(self, x, reverse=False):
    """Apply `Ops.MOD` to `self` and `x`; with `reverse` set, `x` is the left operand."""
    return self._binop(Ops.MOD, x, reverse)

  def sub(self, x, reverse=False):
    """Subtract by adding the negated right-hand operand: `self + (-x)`, or `x + (-self)` when `reverse` is set."""
    if reverse: return self.ufix(x).alu(Ops.ADD, -self)
    return self.alu(Ops.ADD, self.ufix(-x))

  def div(self, x, reverse=False):
    """Divide by multiplying with an `Ops.RECIP` of the divisor: `self * recip(x)`, or `x * recip(self)` when `reverse` is set."""
    if reverse: return self.ufix(x)*self.alu(Ops.RECIP)
    return self*self.ufix(x).alu(Ops.RECIP)

  # ***** operator overloads: unary and forward binary *****

  def __neg__(self): return self.neg()

  def __add__(self, x): return self.add(x)
  def __sub__(self, x): return self.sub(x)
  def __mul__(self, x): return self.mul(x)
  def __truediv__(self, x): return self.div(x)
  def __floordiv__(self, x): return self.idiv(x)  # TODO: idiv is trunc div, not floordiv
  def __mod__(self, x): return self.mod(x)
  def __and__(self, x): return self.bitwise_and(x)
  def __or__(self, x): return self.bitwise_or(x)
  def __xor__(self, x): return self.bitwise_xor(x)

  # ***** operator overloads: reflected binary (self is the right-hand operand) *****

  def __radd__(self, x): return self.add(x, True)
  def __rsub__(self, x): return self.sub(x, True)
  def __rmul__(self, x): return self.mul(x, True)
  def __rtruediv__(self, x): return self.div(x, True)
  def __rfloordiv__(self, x): return self.idiv(x, True)
  def __rand__(self, x): return self.bitwise_and(x, True)
  def __ror__(self, x): return self.bitwise_or(x, True)
  def __rxor__(self, x): return self.bitwise_xor(x, True)
  def __rmod__(self, x): return self.mod(x, True)

  # ***** comparisons: everything is phrased through CMPLT / CMPNE *****

  def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
  def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
  def __ge__(self, x): return (self < x).logical_not()
  def __le__(self, x): return (self > x).logical_not()

  def ne(self, x):
    """Apply `Ops.CMPNE` to `self` and `x`."""
    return self.alu(Ops.CMPNE, self.ufix(x))

  def eq(self, x):
    """Return the logical NOT of `self.ne(x)`."""
    return self.ne(x).logical_not()

  def __ne__(self, x): return self.ne(x)
  # NOTE: __eq__ isn't overridden, and means the same thing as is by default

  # ***** shifts *****

  def lshift(self, x, reverse=False):
    """Apply `Ops.SHL` to `self` and `x`; with `reverse` set, `x` is the left operand."""
    return self._binop(Ops.SHL, x, reverse)

  def rshift(self, x, reverse=False):
    """Apply `Ops.SHR` to `self` and `x`; with `reverse` set, `x` is the left operand."""
    return self._binop(Ops.SHR, x, reverse)

  def __lshift__(self, x): return self.lshift(x)
  def __rshift__(self, x): return self.rshift(x)
  def __rlshift__(self, x): return self.lshift(x, True)
  def __rrshift__(self, x): return self.rshift(x, True)

  # ***** selection and unary math *****

  def maximum(self, x):
    """Apply `Ops.MAX` to `self` and `x`."""
    return self.alu(Ops.MAX, self.ufix(x))

  def minimum(self, x):
    """Compute the minimum as the negated maximum of the negated operands."""
    negated_max = (-self).maximum(-x)
    return -negated_max

  def where(self, x, y):
    """
    Apply `Ops.WHERE` with `self` as the first operand, followed by `x` and `y`.

    Whichever of `x` / `y` has exactly the same type as `self` is used to `ufix` the other one.
    Raises `RuntimeError` when neither of them does.
    """
    own_type = type(self)
    if own_type is type(x): return self.alu(Ops.WHERE, x, x.ufix(y))
    if own_type is type(y): return self.alu(Ops.WHERE, y.ufix(x), y)
    raise RuntimeError("where needs at least one UOp arg")

  def threefry(self, seed):
    """Apply `Ops.THREEFRY` to `self` and `seed` (`seed` is passed through without `ufix`)."""
    return self.alu(Ops.THREEFRY, seed)

  def reciprocal(self):
    """Apply `Ops.RECIP` to `self`."""
    return self.alu(Ops.RECIP)

  def trunc(self):
    """Apply `Ops.TRUNC` to `self`."""
    return self.alu(Ops.TRUNC)

  def sqrt(self):
    """Apply `Ops.SQRT` to `self`."""
    return self.alu(Ops.SQRT)

  def sin(self):
    """Apply `Ops.SIN` to `self`."""
    return self.alu(Ops.SIN)

  def log2(self):
    """Apply `Ops.LOG2` to `self`."""
    return self.alu(Ops.LOG2)

  def exp2(self):
    """Apply `Ops.EXP2` to `self`."""
    return self.alu(Ops.EXP2)

  def pow(self, x):
    """Apply `Ops.POW` to `self` and `x`."""
    return self.alu(Ops.POW, self.ufix(x))
|