tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/uop/__init__.py (new file)
@@ -0,0 +1,117 @@
+ from enum import auto, IntEnum, Enum
+
+ # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses
+ class FastEnum(IntEnum):
+   def __str__(self): return Enum.__str__(self)
+   @staticmethod
+   def _generate_next_value_(_, __, ___, last_values): return 1 + max([0, *last_values, *[max(c) for c in FastEnum.__subclasses__()]])
+
+ # the order of these Ops controls the order of the toposort
+ class Ops(FastEnum):
+   # uops that aren't rendered
+   NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto(); REWRITE_ERROR = auto() # noqa: E702
+
+   # track children
+   CHILD = auto(); CHILDREN = auto() # noqa: E702
+
+   # buffer ops
+   COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
+
+   # create buffer
+   BUFFERIZE = auto()
+
+   # ops that adjust the behavior of the scheduler
+   CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702
+
+   # blocks in linearizer (only used there)
+   BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702
+
+   # movement ops! these only exist in the tensor graph
+   RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
+   MULTI = auto() # MULTI is really a movement op
+
+   # view is what all movement ops become
+   VIEW = auto()
+
+   # TODO: remove VALID with the VIEW(CONST(DEVICE)) refactor
+   VALID = auto()
+
+   # TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
+   DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
+
+   # this is for symbolic shapes
+   DEFINE_VAR = auto(); BIND = auto() # noqa: E702
+
+   # this is a RANGE for GPU dimensions, similar to symbolic shapes but not exactly
+   SPECIAL = auto()
+
+   # reduce
+   REDUCE_AXIS = auto(); REDUCE = auto(); ALLREDUCE = auto() # noqa: E702
+
+   # optimization helper ops
+   UNROLL = auto(); CONTRACT = auto(); GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
+
+   # UnaryOps
+   CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto(); TRUNC = auto() # noqa: E702
+
+   # load/store before math
+   LOAD = auto(); STORE = auto() # noqa: E702
+   ASSIGN = auto() # TODO: ASSIGN is STORE, remove ASSIGN
+
+   # tensor core math op, not elementwise
+   WMMA = auto()
+
+   # INDEX is a BinaryOp similar to ADD, but it operates on pointers
+   INDEX = auto()
+
+   # BinaryOps
+   ADD = auto(); MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); MAX = auto(); MOD = auto() # noqa: E702
+   CMPLT = auto(); CMPNE = auto(); CMPEQ = auto() # noqa: E702
+   XOR = auto(); OR = auto(); AND = auto() # noqa: E702
+   THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
+
+   # TernaryOps
+   WHERE = auto(); MULACC = auto() # noqa: E702
+
+   # control flow ops
+   BARRIER = auto(); RANGE = auto(); IF = auto(); ENDRANGE = auto(); ENDIF = auto() # noqa: E702
+
+   # consts. VCONST is a vectorized const
+   VCONST = auto(); CONST = auto() # noqa: E702
+
+   # CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
+   CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
+
+ class GroupOp:
+   Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG, Ops.TRUNC}
+   Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
+             Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
+   Ternary = {Ops.WHERE, Ops.MULACC}
+   ALU = set.union(Unary, Binary, Ternary)
+
+   # TODO: is BITCAST always Elementwise if it's shape changing?
+   Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
+
+   Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
+
+   Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
+   Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
+
+   Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
+   Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART}
+
+   # BinaryOps that can be flipped
+   Commutative = {Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPNE, Ops.CMPEQ, Ops.XOR, Ops.AND, Ops.OR}
+
+   # BinaryOps where f(f(a,b),c) = f(a,f(b,c))
+   Associative = {Ops.ADD, Ops.MUL, Ops.AND, Ops.OR, Ops.MAX}
+
+   # BinaryOps that satisfy f(x,x)=x see https://en.wikipedia.org/wiki/Idempotence
+   Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
+
+   # do not preserve f(0) = 0
+   UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
+
+   Meta = {Ops.COPY, Ops.BUFFER_VIEW}
+
+   All = set(Ops)
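Note: FastEnum, added above, keeps auto() values unique across every FastEnum subclass while restoring Enum-style printing on top of IntEnum. A minimal sketch of what that buys (assumes tinygrad 0.11.0 is installed; MyOps is a hypothetical subclass, not part of the diff):

    from enum import auto
    from tinygrad.uop import FastEnum, Ops, GroupOp

    class MyOps(FastEnum):
      FOO = auto()  # _generate_next_value_ scans existing subclasses, so FOO can't collide with Ops

    assert MyOps.FOO > max(Ops)             # unique across all FastEnum subclasses
    assert Ops.ADD in GroupOp.Commutative   # GroupOp classifies ops with plain sets
    print(str(Ops.ADD))                     # "Ops.ADD" -- Enum.__str__, not IntEnum's bare int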
tinygrad/{codegen/transcendental.py → uop/decompositions.py}
@@ -1,7 +1,9 @@
- import math
- from tinygrad.dtype import dtypes, DType
- from tinygrad.helpers import polyN
- from tinygrad.ops import UOp
+ from typing import Callable
+ import math, functools
+ from tinygrad.dtype import dtypes, DType, promo_lattice
+ from tinygrad.device import is_dtype_supported
+ from tinygrad.helpers import polyN, DISABLE_FAST_IDIV
+ from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher

  TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)

@@ -10,9 +12,9 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
    return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)

  # *** helper functions for bit manipulation ***
- def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1]
- def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d]
- def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d]
+ def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1]
+ def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d.scalar()]
+ def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()]

  # **** utils ****
  def shr(x:UOp, y:int) -> UOp: return x // (2**y)
@@ -20,41 +22,41 @@ def shl(x:UOp, y:int) -> UOp: return x * (2**y)

  def rintk(d:UOp) -> UOp:
    """round d:float to int away from 0"""
-   out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
+   out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount)
    return (d + (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)

  def pow2if(q:UOp, float_dtype:DType):
    """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
-   out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype]
+   out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype.scalar()}[q.dtype.scalar()].vec(q.dtype.vcount)
    return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype)

  def ilogb2k(d:UOp) -> UOp:
    """calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
-   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
-   dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype])
+   assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
+   dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount))
    # -1 <= ilog2bk(d) <= 128
    return (shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype)

  def ldexp3k(d:UOp, e:UOp) -> UOp:
    """d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
-   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
-   cast_map = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}
-   m1 = d.bitcast(cast_map[d.dtype])
-   m2 = shl(e.cast(cast_map[d.dtype]), mantissa_bits(d.dtype))
+   assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
+   dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.count)
+   m1 = d.bitcast(dtype)
+   m2 = shl(e.cast(dtype), mantissa_bits(d.dtype))
    return (m1 + m2).bitcast(d.dtype).cast(d.dtype)

  def ldexp2k(d:UOp, e:UOp) -> UOp:
    """d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
-   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
+   assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in (dtypes.int16, dtypes.int32, dtypes.int64)
    return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)

  def frexp(v:UOp) -> tuple[UOp, UOp]:
    """frexp(v) -> (mantissa, exponent) assuming v != 0"""
-   assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   assert v.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
    # m1 = masks for mantissa, m2 = masks to normalize the mantissa.
-   m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
-   m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype]
-   bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype])
+   m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype.scalar()]
+   m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype.scalar()]
+   bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype.scalar()].vec(v.dtype.count))
    exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype)
    # Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
    mantissa = ((bits & m1) | m2).bitcast(v.dtype)
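A pattern repeated throughout these hunks: dtype-keyed lookups now go through d.dtype.scalar() and the result is re-vectorized with .vec(...), so the helpers also work on vectorized UOps. A small illustration of the dtype API involved (assumes tinygrad 0.11.0; illustrative only):

    from tinygrad.dtype import dtypes

    v = dtypes.float32.vec(4)            # a 4-wide float32
    assert v.scalar() == dtypes.float32  # scalar() drops the vector width
    # key the table by the scalar dtype, then rebuild the same width
    out = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32}[v.scalar()].vec(v.vcount)
    assert out == dtypes.int32.vec(4)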
@@ -70,12 +72,12 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
    - `r`[d.dtype] is the reminder value corresponding to `round_to_nearest(x % pi/2)`.
    - `q`[int32] is an integer, and q % 4 is corresponding to the quadrant of the original angle `d`.
    """
-   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
    # https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
    # 190 bits of 2/pi for Payne-Hanek style argument reduction
    two_over_pi_f = [0x00000000, 0x28be60db, 0x9391054a, 0x7f09d5f4, 0x7d4d3770, 0x36d8a566, 0x4f10e410]

-   intermediate_dtype = dtypes.float32 if d.dtype == dtypes.float16 else d.dtype
+   intermediate_dtype = dtypes.float32.vec(d.dtype.count) if d.dtype.base.scalar() == dtypes.float16 else d.dtype

    f, e = frexp(d)
    ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(dtypes.uint64)
@@ -89,10 +91,10 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
      if count+offset < len(two_over_pi_f) - 1:
        an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
      return an
-   def _shl_lazy(x, y): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
-   def _shr_lazy(x, y): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
+   def _shl_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
+   def _shr_lazy(x:UOp, y:UOp): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)

-   a = [_take(UOp.const(dtypes.uint32, 0), i) for i in range(4)]
+   a = [_take(UOp.const(dtypes.uint32.vec(d.dtype.count), 0), i) for i in range(4)]
    # (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
    # Note: e >= 1 for all numbers d >= 1.0. assume e != 0
    hi = _shl_lazy(a[0], e) | _shr_lazy(a[1], offset)
@@ -119,7 +121,7 @@ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
    """
    def _reduce_d(x:UOp, q:UOp):
      # https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefdp.c#L789-L823
-     if x.dtype == dtypes.float64:
+     if x.dtype.scalar() == dtypes.float64:
        # https://github.com/shibatch/sleef/blob/f6d8a841fbfddd26ce712834d4da220cd76048fb/src/common/misc.h#L77
        PI_A, PI_B, PI_C, PI_D = 3.1415926218032836914, 3.1786509424591713469e-08, 1.2246467864107188502e-16, 1.2736634327021899816e-24
        d = qdh * -PI_A + x
@@ -129,7 +131,7 @@ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:
        d = qdh * -PI_C + d
        d = q * -PI_C + d
        d = (qdh + q) * -PI_D + d
-     elif x.dtype == dtypes.float16:
+     elif x.dtype.scalar() == dtypes.float16:
        # [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
        d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
      else:
@@ -142,11 +144,11 @@ def cody_waite_reduction(d:UOp) -> tuple[UOp, UOp]:

    m_1_pi = 0.318309886183790671537767526745028724
    qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64).cast(d.dtype) * (2.0**24)
-   quadrant = rintk(d * m_1_pi -qdh) if d.dtype == dtypes.float64 else rintk(d * m_1_pi)
+   quadrant = rintk(d * m_1_pi -qdh) if d.dtype.base.scalar() == dtypes.float64 else rintk(d * m_1_pi)
    return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32)

  # *** approximate sine on small angle. ***
- def trig_poly(d:UOp, coeff32, coeff64): return d * (polyN(d*d, coeff64) if d.dtype == dtypes.float64 else polyN(d*d, coeff32))
+ def trig_poly(d:UOp, coeff32, coeff64): return d * (polyN(d*d, coeff64) if d.dtype.scalar() == dtypes.float64 else polyN(d*d, coeff32))
  # approximate sine on [-pi/2, pi/2]
  def sin_poly(d:UOp) -> UOp:
    return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938, 1.0],
@@ -172,7 +174,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
    - fast=True assumes x <= switch_over.
    - switch_over is the threshold for switching to payne_hanek_reduction.
    """
-   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
    # mask +-inf/nan as zero
    x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
    # x_sign = sign(x)
@@ -194,20 +196,20 @@ def xexp2(d:UOp) -> UOp:
    Implements a 1.0 ULP approximation for Ops.EXP2
    - Paper: https://arxiv.org/pdf/2001.09258
    """
-   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
    # mask +=inf/nan as zero.
    x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
    q = rintk(x)
    # s = d - round(d)
    s = x - q.cast(x.dtype)
    # a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2].
-   if d.dtype == dtypes.float64:
+   if d.dtype.scalar() == dtypes.float64:
      u = polyN(s, [0.4434359082926529454e-9, 0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4,
                    0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
                    0.6931471805599452862e+0, 0.1000000000000000000e+1])
    else: u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0])
    u = ldexp2k(u, q) # u*2^q
-   upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype]
+   upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype.scalar()]
    # Replace x >= upper with +inf
    u = (d >= upper).where(d.const_like(math.inf), u)
    # Replace x < lower with zero.
@@ -220,10 +222,10 @@ def xlog2(d:UOp) -> UOp:
    Implements a 1.0 ULP approximation for Ops.LOG2
    Paper: https://arxiv.org/pdf/2001.09258 5.5
    """
-   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
    # TODO: float16 denormal need float32 to achieve precision
-   if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
-   FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
+   if d.dtype.scalar() == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
+   FLT_MIN = d.const_like(1e-6 if d.dtype.scalar() == dtypes.float16 else 1e-4)
    is_denormal = d<FLT_MIN
    a = is_denormal.where(d * (2 ** 64), d)

@@ -233,7 +235,7 @@ def xlog2(d:UOp) -> UOp:

    x = (m - 1.0) / (m + 1.0)
    x2 = x * x
-   if d.dtype == dtypes.float64:
+   if d.dtype.scalar() == dtypes.float64:
      t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
                     0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
      s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
@@ -248,7 +250,7 @@ def xlog2(d:UOp) -> UOp:
    r = (d<-0.0).where(r.const_like(math.nan), r)
    # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
    # log2_zero = the value of unmasked xlog2(0.0).
-   log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
+   log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype.scalar()]
    r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
    # log2(NaN) = NaN
    r = d.ne(d).where(r.const_like(math.nan), r)
@@ -264,3 +266,88 @@ def xpow(base:UOp, exponent:UOp) -> UOp:
    (exponent < 0).where(-exponent, exponent).cast(dtypes.int32).mod(2).cast(dtypes.bool).where(ret.const_like(-1), ret.const_like(1)))
    # fix 0 ** 0 = 1
    return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * (base < 0).where(adj, ret.const_like(1)))
+
+ # *** integer division ***
+
+ @functools.lru_cache(None)
+ def magicgu(vmax:int, d:int) -> tuple[int,int]:
+   # calculate m,s such that x//d == (x*m) >> s for all 0 <= x <= vmax, d>0; adapted from Hacker's Delight, Chapter 10
+   nc = (vmax+1)//(d) * d - 1
+   nbits = vmax.bit_length()
+   for s in range(0, 2*nbits + 1):
+     if 2**s > nc*(d - 1 - (2**s - 1) % d):
+       m = (2**s + d - 1 - (2**s - 1) % d)//d
+       return m, s
+   assert False
+
+ def fast_idiv(device: str, x: UOp, d: int) -> UOp|None:
+   # If d is a power of two this is not valid for signed ints!
+   is_unsigned = True if x.vmin>=0 or x.dtype in dtypes.uints else False
+   assert d>0, "Sign should have been taken out of divisor"
+   vmin,vmax = max(x.vmin, x.dtype.min), min(x.vmax, x.dtype.max)
+   m,s = magicgu(max(vmax, abs(vmin)), d)
+   if m*vmin >= dtypes.min(x.dtype) and m*vmax <= dtypes.max(x.dtype):
+     return ((x*m) >> s) if is_unsigned else ((x*m) >> s) + (x<0).where(x.ufix(1), 0)
+   # promo_lattice needs to return an unsigned type if the type is unsigned
+   if dtypes.is_int(next_dtype := promo_lattice[x.dtype][-1]) and is_dtype_supported(next_dtype, None if device=='' else device):
+     if m*vmin >= dtypes.min(next_dtype) and m*vmax <= dtypes.max(next_dtype):
+       return ((x.cast(next_dtype)*m) >> s).cast(x.dtype) if is_unsigned else ((x.cast(next_dtype)*m) >> s).cast(x.dtype) + (x<0).where(x.ufix(1), 0)
+   return None
+
+ # ***** threefry *****
+
+ def threefry2x32(x: UOp, key: UOp):
+   # split x and key from uint64 to two uint32
+   x0, x1 = (x & 0xffffffff).cast(dtypes.uint32), ((x // 2**32) & 0xffffffff).cast(dtypes.uint32)
+   key0, key1 = (key & 0xffffffff).cast(dtypes.uint32), ((key // 2**32) & 0xffffffff).cast(dtypes.uint32)
+
+   rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
+   ks = [key1, key0 ^ key1 ^ 0x1BD11BDA, key0]
+   xr:list[UOp] = [x0 + ks[-1], x1 + ks[0]]
+   for i in range(5):
+     for r in rotations[i % 2]: xr[0], xr[1] = (x0 := xr[0] + xr[1]), x0 ^ ((xr[1] * 2**r) + (xr[1] // 2**(32 - r)))
+     xr = [(xr[0] + ks[i % 3]), (xr[1] + ks[(i + 1) % 3] + i + 1)]
+
+   return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
+
+ # ***** decomposition patterns *****
+
+ powers_of_two = {2**i:i for i in range(64)}
+ @functools.cache
+ def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental=False):
+   pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
+     ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
+   # no real hardware supports THREEFRY, but NullRenderer does
+   if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
+   # MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
+   if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
+   # rewrite SQRT to xpow 0.5
+   if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
+   # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
+   if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
+   # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
+   if Ops.SHL in ops: pat += [(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << v if (v:=powers_of_two.get(c.arg, 0)) else None)]
+   if Ops.SHR in ops:
+     # no reason to check x<0 for uints
+     pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
+     pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(
+       c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v
+   if not DISABLE_FAST_IDIV:
+     pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d"), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
+     pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
+   if Ops.NEG in ops:
+     pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
+     if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
+   if Ops.CMPLT in ops:
+     # These are late rewrites because simplex expects equalities to be a certain format
+     pat += [
+       ((UPat.var("x", dtypes.sints) < UPat.cvar("c", dtypes.sints)).logical_not(), lambda x,c: c-1<x),
+       ((UPat.cvar("c", dtypes.sints) < UPat.var("x", dtypes.sints)).logical_not(), lambda x,c: x<c+1),
+       (UPat.var("x", dtypes.sints)*-1 < UPat.var("y", dtypes.sints)*UPat.cvar("c"), lambda x,y,c: y*(-c)<x),
+       (UPat.var("x", dtypes.sints)*-1 < UPat.cvar("c"), lambda x,c:-c<x),
+       ((UPat.cvar("c1",vec=False)<UPat.var("x", dtypes.sints)) & (UPat.var("x", dtypes.sints)<UPat.cvar("c2",vec=False)),
+        lambda x,c1,c2: x.eq(c1+1) if c1.arg+1==c2.arg-1 else None), # (c-1)<x & x<(c+1) -> x==c
+     ]
+   if Ops.CMPEQ in ops: pat += [(UPat.var('x').ne(UPat.var('y')).logical_not(), lambda x,y: x.alu(Ops.CMPEQ, y))]
+   if Ops.MULACC in ops: pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))]
+   return PatternMatcher(pat)
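The magicgu/fast_idiv pair above implements division by a constant via multiply-and-shift: find m and s with x//d == (x*m) >> s over the needed range, then only emit the rewrite when the intermediate x*m provably fits the (possibly promoted) integer type. A standalone sanity check of the identity in plain Python (a sketch mirroring the function above, not an import from tinygrad):

    import functools

    @functools.lru_cache(None)
    def magicgu(vmax:int, d:int) -> tuple[int,int]:
      # find m,s with x//d == (x*m) >> s for all 0 <= x <= vmax (Hacker's Delight, ch. 10)
      nc = (vmax+1)//d * d - 1
      for s in range(0, 2*vmax.bit_length() + 1):
        if 2**s > nc*(d - 1 - (2**s - 1) % d): return (2**s + d - 1 - (2**s - 1) % d)//d, s
      raise AssertionError

    m, s = magicgu(2**16, 7)
    assert all(x // 7 == (x * m) >> s for x in range(2**16 + 1))  # brute-force verify the range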
tinygrad/uop/mathtraits.py (new file)
@@ -0,0 +1,169 @@
+ from tinygrad.uop import Ops
+ from tinygrad.helpers import T
+ from tinygrad.dtype import dtypes
+
+ class MathTrait:
+   # required to implement
+   def alu(self:T, op:Ops, *src) -> T: raise NotImplementedError
+   def const_like(self:T, b) -> T: raise NotImplementedError
+
+   # great functions you get!
+   def ufix(self, x): return self.const_like(x) if not isinstance(x, MathTrait) else x
+   def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
+   def logical_not(self): return self.ne(True)
+   def neg(self):
+     if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
+     return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
+   def _check_dtype(self):
+     if (dtype:=getattr(self, 'dtype')) is not None:
+       if isinstance(dtype, tuple): dtype = dtype[0]
+       if not (dtypes.is_bool(dtype) or dtypes.is_int(dtype)): raise RuntimeError(f"{dtype} is not supported")
+   def add(self, x, reverse=False):
+     """
+     Adds `self` and `x`.
+     Equivalent to `self + x`.
+     Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
+     ```python exec="true" source="above" session="tensor" result="python"
+     Tensor.manual_seed(42)
+     t = Tensor.randn(4)
+     print(t.numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(t.add(20).numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(t.add(Tensor([[2.0], [3.5]])).numpy())
+     ```
+     """
+     return self._binop(Ops.ADD, x, reverse)
+   def mul(self, x, reverse=False):
+     """
+     Multiplies `self` and `x`.
+     Equivalent to `self * x`.
+     Supports broadcasting to a common shape, type promotion, and integer, float, boolean inputs.
+
+     ```python exec="true" source="above" session="tensor" result="python"
+     Tensor.manual_seed(42)
+     t = Tensor.randn(4)
+     print(t.numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(t.mul(3).numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
+     ```
+     """
+     return self._binop(Ops.MUL, x, reverse)
+   def bitwise_and(self, x, reverse=False):
+     """
+     Computes the bitwise AND of `self` and `x`.
+     Equivalent to `self & x`.
+     Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([2, 5, 255]).bitwise_and(Tensor([3, 14, 16])).numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([True, True, False, False]).bitwise_and(Tensor([True, False, True, False])).numpy())
+     ```
+     """
+     self._check_dtype()
+     return self._binop(Ops.AND, x, reverse)
+   def bitwise_or(self, x, reverse=False):
+     """
+     Computes the bitwise OR of `self` and `x`.
+     Equivalent to `self | x`.
+     Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([2, 5, 255]).bitwise_or(Tensor([4, 4, 4])).numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([True, True, False, False]).bitwise_or(Tensor([True, False, True, False])).numpy())
+     ```
+     """
+     self._check_dtype()
+     return self._binop(Ops.OR, x, reverse)
+   def bitwise_xor(self, x, reverse=False):
+     """
+     Computes bitwise xor of `self` and `x`.
+     Equivalent to `self ^ x`.
+     Supports broadcasting to a common shape, type promotion, and integer, boolean inputs.
+
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([-1, -2, 3]).bitwise_xor(Tensor([1, 0, 3])).numpy())
+     ```
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([True, True, False, False]).bitwise_xor(Tensor([True, False, True, False])).numpy())
+     ```
+     """
+     self._check_dtype()
+     return self._binop(Ops.XOR, x, reverse)
+   def idiv(self, x, reverse=False):
+     """
+     Divides `self` by `x`.
+     Equivalent to `self // x`.
+     Supports broadcasting to a common shape, type promotion, and integer inputs.
+     `idiv` performs integer division (truncate towards zero).
+
+     ```python exec="true" source="above" session="tensor" result="python"
+     print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
+     ```
+     """
+     return self._binop(Ops.IDIV, x, reverse)
+   def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
+   def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
+   def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
+
+   def __neg__(self): return self.neg()
+
+   def __add__(self, x): return self.add(x)
+   def __sub__(self, x): return self.sub(x)
+   def __mul__(self, x): return self.mul(x)
+   def __truediv__(self, x): return self.div(x)
+   def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
+   def __mod__(self, x): return self.mod(x)
+   def __and__(self, x): return self.bitwise_and(x)
+   def __or__(self, x): return self.bitwise_or(x)
+   def __xor__(self, x): return self.bitwise_xor(x)
+
+   def __radd__(self, x): return self.add(x, True)
+   def __rsub__(self, x): return self.sub(x, True)
+   def __rmul__(self, x): return self.mul(x, True)
+   def __rtruediv__(self, x): return self.div(x, True)
+   def __rfloordiv__(self, x): return self.idiv(x, True)
+   def __rand__(self, x): return self.bitwise_and(x, True)
+   def __ror__(self, x): return self.bitwise_or(x, True)
+   def __rxor__(self, x): return self.bitwise_xor(x, True)
+   def __rmod__(self, x): return self.mod(x, True)
+
+   def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
+   def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
+   def __ge__(self, x): return (self < x).logical_not()
+   def __le__(self, x): return (self > x).logical_not()
+
+   def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x))
+   def eq(self, x): return self.ne(x).logical_not()
+   def __ne__(self, x): return self.ne(x)
+   # NOTE: __eq__ isn't overridden, and means the same thing as is by default
+
+   def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse)
+   def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse)
+   def __lshift__(self, x): return self.lshift(x)
+   def __rshift__(self, x): return self.rshift(x)
+   def __rlshift__(self, x): return self.lshift(x, True)
+   def __rrshift__(self, x): return self.rshift(x, True)
+
+   def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
+   def minimum(self, x): return -(-self).maximum(-x)
+   def where(self, x, y):
+     if type(self) is type(x): return self.alu(Ops.WHERE, x, x.ufix(y))
+     if type(self) is type(y): return self.alu(Ops.WHERE, y.ufix(x), y)
+     raise RuntimeError("where needs at least one UOp arg")
+   def threefry(self, seed): return self.alu(Ops.THREEFRY, seed)
+   def reciprocal(self): return self.alu(Ops.RECIP)
+   def trunc(self): return self.alu(Ops.TRUNC)
+   def sqrt(self): return self.alu(Ops.SQRT)
+   def sin(self): return self.alu(Ops.SIN)
+   def log2(self): return self.alu(Ops.LOG2)
+   def exp2(self): return self.alu(Ops.EXP2)
+   def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
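MathTrait packs the whole Python operator surface (arithmetic, reversed, shift, and comparison operators) on top of just two abstract methods, alu and const_like. A toy sketch of how a subclass inherits all of it (Sym is hypothetical, for illustration only; assumes tinygrad 0.11.0):

    from tinygrad.uop import Ops
    from tinygrad.uop.mathtraits import MathTrait

    class Sym(MathTrait):
      dtype = None  # neg()/_check_dtype() look this attribute up via getattr
      def __init__(self, s): self.s = s
      def alu(self, op, *src): return Sym(f"{op.name}({', '.join([self.s]+[x.s for x in src])})")
      def const_like(self, b): return Sym(str(b))

    x = Sym("x")
    print((x + 1).s)  # ADD(x, 1)   -- 1 is lifted by ufix/const_like
    print((2 * x).s)  # MUL(2, x)   -- 2*x routes through __rmul__, so operands flip
    print((x < 3).s)  # CMPLT(x, 3)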