tinygrad 0.9.1-py3-none-any.whl → 0.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/linearize.py
@@ -0,0 +1,95 @@
+ from typing import List, Set, Dict, Tuple
+ import functools, heapq
+ from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops, GroupOp
+ from tinygrad.dtype import dtypes
+ from tinygrad.helpers import DEBUG
+
+ def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
+   if u in children: return srcs[u]
+   srcs[u] = {}
+   children[u] = []
+   for x in u.src:
+     srcs[u].update(get_children_dfs(x, children, srcs, in_degree))
+     if x.op is Ops.RANGE and x.arg[1]: srcs[u][x] = None
+     children[x].append(u)
+   in_degree[u] = len(u.src)
+   return srcs[u]
+
+ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
+   assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
+   # filter nodes that don't link to a sink
+   # BFS toposort
+   children: Dict[UOp, List[UOp]] = {}
+   range_srcs: Dict[UOp, Dict[UOp, None]] = {}
+   in_degree: Dict[UOp, int] = {}
+   get_children_dfs(sink, children, range_srcs, in_degree)
+
+   @functools.lru_cache(None)
+   def get_recursive_children(x:UOp, end:Ops, include_self=False) -> Set[UOp]:
+     if x.op is Ops.SINK: return set()
+     return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
+
+   # scope children impact the toposort and END* insertion
+   scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
+   range_phi = {r:[p for p in scope_children[r] if p.op is Ops.ASSIGN] for r in scope_children if r.op is Ops.RANGE}
+
+   # assign priorities
+   def get_priority(u:UOp):
+     priority = 0
+     # prefer ranges that depend on the least number of independent ranges
+     if u.op is Ops.RANGE and u.arg[1]:
+       priority += u.arg[0]
+       for p in range_phi[u]:
+         priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
+     elif u.op is Ops.CONST:
+       # place consts first here, they don't do anything and it can cause issues with DEFINE_ACC
+       priority -= 100000000000
+     else:
+       # prefer uops that are loop children
+       priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is Ops.RANGE and u in ss])
+     if u.op is Ops.IF and len(u.src) == 1: priority += 10000000 # if penalty
+     return priority
+   priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
+
+   # prevent priority inversion
+   @functools.lru_cache(None)
+   def fix_priority(u:UOp, lowest_priority):
+     if u.op in {Ops.CAST, Ops.BITCAST, *GroupOp.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
+       priorities[u] = min(priorities[u], lowest_priority)
+       if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
+     for x in u.src: fix_priority(x, priorities[u])
+   fix_priority(sink, 0)
+
+   # NOTE: the compare should never make it all the way to u
+   queue:List[Tuple[int, Tuple, UOp]] = []
+   def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
+
+   for u in children:
+     if in_degree[u] == 0: push(u)
+
+   scope_end: Dict[UOp, UOp] = {}
+   _uops: List[UOp] = []
+   while queue:
+     p,_,x = heapq.heappop(queue)
+     if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
+     if x in scope_children: scope_end[x] = x
+     if x.op is Ops.DEFINE_ACC:
+       idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE])
+       _uops.insert(idx, x)
+     else: _uops.append(x)
+     for u, ss in scope_children.items():
+       if x in ss:
+         ss.remove(x)
+         if len(ss) == 0: scope_end[u] = x
+     for u in children[x]:
+       in_degree[u] -= 1
+       if in_degree[u] == 0: push(u)
+
+   # end scopes in toposort order
+   for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], dtypes.void, (u,)))
+
+   # sanity checks (NOTE: these can cause things to be skipped in BEAM)
+   if not skip_check: type_verify(_uops)
+
+   # strip the SINK
+   return _uops[:-1]
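
A minimal, standalone sketch (not from the package; plain dicts and string node names stand in for UOps) of the scheduling idea linearize_uop uses above: a node becomes ready once all of its sources have been emitted, and ties among ready nodes are broken by precomputed priorities through a heap.

import heapq

def toposort_with_priority(children, in_degree, priority):
  # emit ready nodes (in_degree == 0) in priority order, smallest priority first
  queue, order = [], []
  for node, deg in in_degree.items():
    if deg == 0: heapq.heappush(queue, (priority[node], node))
  while queue:
    _, node = heapq.heappop(queue)
    order.append(node)
    for child in children[node]:
      in_degree[child] -= 1
      if in_degree[child] == 0: heapq.heappush(queue, (priority[child], child))
  return order

# toy DAG: a -> {b, c} -> d; c gets a lower priority, so it is emitted before b
children = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
in_degree = {"a": 0, "b": 1, "c": 1, "d": 2}
priority = {"a": 0, "b": 5, "c": -5, "d": 0}
assert toposort_with_priority(children, in_degree, priority) == ["a", "c", "b", "d"]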
tinygrad/codegen/lowerer.py
@@ -0,0 +1,143 @@
+ # the job of the lowerer is to do indexing
+ from __future__ import annotations
+ import functools, itertools, operator
+ from dataclasses import dataclass
+ from typing import List, Tuple, cast, Optional
+ from tinygrad.shape.shapetracker import ShapeTracker
+ from tinygrad.shape.view import variable_to_uop
+ from tinygrad.dtype import dtypes, PtrDType
+ from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
+ from tinygrad.renderer import Renderer
+ from tinygrad.helpers import all_int, prod, partition, flatten
+
+ # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+ def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+   acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+   try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+   except ValueError: return None
+   return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+
+ # ***** indexing *****
+
+ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
+   # TODO: symbolic shape
+   if not all_int(dims): return dims
+   while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
+     for i,m in enumerate(max_sizes):
+       if dims[i] * dims[i+1] <= m:
+         dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
+         break
+     else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+   return dims
+
+ def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
+   if reverse: dims = dims[::-1]
+   limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
+   ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
+   if limited != dims:
+     ret = []
+     # cast for mypy, get_contraction won't be None
+     for idx, contraction in zip(raw_idxs, cast(List[List[int]], get_contraction(dims, limited))):
+       if len(contraction) == 1: ret.append(idx)
+       else:
+         for c in contraction:
+           ret.append(idx % dims[c])
+           idx //= dims[c]
+   return ret[::-1] if reverse else ret
+
+ @dataclass
+ class IndexContext:
+   idxs: List[UOp]
+   ridxs: List[UOp]
+   acc_num: int = 0
+
+ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
+   ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
+   # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
+   full_shape = ast.full_shape
+   first_upcasted = len(full_shape)-ki.upcasted
+   first_output_st: ShapeTracker = ast.src[0].st_arg
+   # if there's no reduce, this is first_upcasted. assumes reduces are at the end
+   first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
+   local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
+   # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
+   group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
+     [[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
+     first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
+   global_dims = first_reduce-ki.local_dims
+
+   if opts.has_local:
+     if ki.dont_use_locals:
+       assert ki.local_dims == 0, "can't use locals if there's no local dims"
+       idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
+     else:
+       # define indexes for GPU-like execution
+       idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
+              get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
+   else:
+     # all loops are RANGES
+     idxs = [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
+       for i,g in enumerate(full_shape[:first_reduce])]
+
+   # reduce loops
+   idxs += [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
+     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
+
+   # upcast loops
+   for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
+     assert isinstance(g, int), "needs to be int to upcast/unroll"
+     idxs.append(UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
+
+   # late indexes (group for reduce)
+   ridxs = idxs[:]
+   for a in range(first_reduce, first_reduce+group_for_reduces):
+     ridxs[a] = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+
+   return IndexContext(idxs, ridxs)
+
+ # ***** lowering (given index) *****
+
+ def lower_reduce_axis(ctx: IndexContext, x: UOp):
+   # NOTE: always using ridxs is fine here
+   reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
+   assert all(x.op is Ops.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
+   alu_op: Ops = x.arg[0]
+   ret = x.src[0]
+   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
+     ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
+     ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
+   if not len(reduce_range): return ret
+   # create ACC and assign
+   acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
+   ctx.acc_num += 1
+   return acc.assign(acc.alu(alu_op, ret))
+
+ def lower_load_store(ctx: IndexContext, x: UOp):
+   idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
+   # TODO: check has_valid in UPat, not here
+   has_valid = valid.op is not Ops.CONST or valid.arg is not True
+   buf = x.src[0]
+   if x.op is Ops.LOAD:
+     barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
+     return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid if has_valid else None),) + barrier)
+   # NOTE: only store the local reduceop in the threads that are actually doing the reduce
+   if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
+     reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
+     store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
+   else: store_back = False
+   # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
+   if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
+   if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
+     for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
+       if oidx is not ridx: valid = valid * oidx.eq(0)
+     has_valid = valid.op is not Ops.CONST or valid.arg is not True
+   return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid if has_valid else None), x.src[2]))
+
+ pm_lowerer = PatternMatcher([
+   (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
+   (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
+   # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
+   (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
+ ])
+
+ def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
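
As an aside (not part of the package diff), the axis-grouping check that get_grouped_dims relies on can be exercised standalone: get_contraction compares prefix products of the two shapes, so it only succeeds when new_shape is formed by merging adjacent axes of old_shape. A small sketch with the same logic over plain tuples:

import itertools, operator

def get_contraction(old_shape, new_shape):
  # new_shape is reachable only if every prefix product of new_shape
  # also appears among the prefix products of old_shape
  acc_old = list(itertools.accumulate(old_shape, operator.mul))
  acc_new = list(itertools.accumulate(new_shape, operator.mul))
  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
  except ValueError: return None
  return [list(range(st, ed)) for st, ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]

print(get_contraction((2, 3, 4), (6, 4)))   # [[0, 1], [2]] -> axes 0 and 1 merge into the first new axis
print(get_contraction((2, 3, 4), (4, 6)))   # None -> 4 is not a prefix product of (2, 3, 4)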
tinygrad/codegen/transcendental.py
@@ -0,0 +1,257 @@
+ import math
+ from typing import Tuple
+ from tinygrad.dtype import dtypes, DType
+ from tinygrad.helpers import polyN
+ from tinygrad.ops import UOp
+
+ TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
+
+ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
+   """replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio"""
+   return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
+
+ # *** helper functions for bit manipulation ***
+ def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1]
+ def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d]
+ def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d]
+
+ # **** utils ****
+ def shr(x:UOp, y:int) -> UOp: return x // (2**y)
+ def shl(x:UOp, y:int) -> UOp: return x * (2**y)
+
+ def rintk(d:UOp) -> UOp:
+   """round d:float to int away from 0"""
+   out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
+   return (d + d.lt(0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)
+
+ def pow2if(q:UOp, float_dtype:DType):
+   """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
+   out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype]
+   return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype)
+
+ def ilogb2k(d:UOp) -> UOp:
+   """calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
+   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype])
+   # -1 <= ilog2bk(d) <= 128
+   return (shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype)
+
+ def ldexp3k(d:UOp, e:UOp) -> UOp:
+   """d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
+   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   cast_map = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}
+   m1 = d.bitcast(cast_map[d.dtype])
+   m2 = shl(e.cast(cast_map[d.dtype]), mantissa_bits(d.dtype))
+   return (m1 + m2).bitcast(d.dtype).cast(d.dtype)
+
+ def ldexp2k(d:UOp, e:UOp) -> UOp:
+   """d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
+   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
+   return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
+
+ def frexp(v:UOp) -> Tuple[UOp, UOp]:
+   """frexp(v) -> (mantissa, exponent) assuming v != 0"""
+   assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   # m1 = masks for mantissa, m2 = masks to normalize the mantissa.
+   m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
+   m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype]
+   bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype])
+   exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype)
+   # Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
+   mantissa = ((bits & m1) | m2).bitcast(v.dtype)
+   exp = exponent - exponent_bias(v.dtype) + 1
+   return mantissa, exp
+
+ # *** reduction algorithms for sine ***
+ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
+   """
+   Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
+     39800.0 <= d <= +Inf
+   Returns a tuple of `(r, q)`:
+   - `r`[d.dtype] is the remainder value corresponding to `round_to_nearest(x % pi/2)`.
+   - `q`[int32] is an integer, and q % 4 corresponds to the quadrant of the original angle `d`.
+   """
+   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   # https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
+   # 190 bits of 2/pi for Payne-Hanek style argument reduction
+   two_over_pi_f = [0x00000000, 0x28be60db, 0x9391054a, 0x7f09d5f4, 0x7d4d3770, 0x36d8a566, 0x4f10e410]
+
+   intermediate_dtype = dtypes.float32 if d.dtype == dtypes.float16 else d.dtype
+
+   f, e = frexp(d)
+   ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(dtypes.uint64)
+   # extract 96 relevant bits of 2/pi based on magnitude of argument
+   i = shr(e.cast(dtypes.uint64), 5)
+   e = e.cast(dtypes.int32) & 31
+   offset = 32 - e
+
+   def _take(an:UOp, offset:int, count:int=0) -> UOp:
+     """an = two_over_pi_f[i+offset]"""
+     if count+offset < len(two_over_pi_f) - 1:
+       an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
+     return an
+   def _shl_lazy(x, y): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
+   def _shr_lazy(x, y): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
+
+   a = [_take(UOp.const(dtypes.uint32, 0), i) for i in range(4)]
+   # (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
+   # Note: e >= 1 for all numbers d >= 1.0. assume e != 0
+   hi = _shl_lazy(a[0], e) | _shr_lazy(a[1], offset)
+   mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset)
+   lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset)
+
+   def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
+   # compute x * 2/pi
+   p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)
+
+   # round quotient to nearest
+   q = shr(p, 62).cast(dtypes.int32)
+   p = p & 0x3fffffffffffffff
+   r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype)
+
+   # if fraction >= 0.5, r -= pi/2, q += 1
+   return f.lt(0.5).where(r, r - math.pi/2), f.lt(0.5).where(q, q + 1)
+
+ def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
+   """
+   Performs Cody-Waite Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
+     0 <= abs(d) <= 39800.0
+   Returns a tuple of `(r, q)`, where the output format is the same as that of `payne_hanek_reduction`.
+   """
+   def _reduce_d(x:UOp, q:UOp):
+     # https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefdp.c#L789-L823
+     if x.dtype == dtypes.float64:
+       # https://github.com/shibatch/sleef/blob/f6d8a841fbfddd26ce712834d4da220cd76048fb/src/common/misc.h#L77
+       PI_A, PI_B, PI_C, PI_D = 3.1415926218032836914, 3.1786509424591713469e-08, 1.2246467864107188502e-16, 1.2736634327021899816e-24
+       d = qdh * -PI_A + x
+       d = q * -PI_A + d
+       d = qdh * -PI_B + d
+       d = q * -PI_B + d
+       d = qdh * -PI_C + d
+       d = q * -PI_C + d
+       d = (qdh + q) * -PI_D + d
+     elif x.dtype == dtypes.float16:
+       # [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
+       d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
+     else:
+       # https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefsp.c#L464-L503
+       d = q * -3.1414794921875 + x
+       d = q * -0.00011315941810607910156 + d
+       d = q * -1.9841872589410058936e-09 + d
+       d = q * -1.2154201256553420762e-10 + d
+     return d
+
+   m_1_pi = 0.318309886183790671537767526745028724
+   qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64).cast(d.dtype) * (2.0**24)
+   quadrant = rintk(d * m_1_pi -qdh) if d.dtype == dtypes.float64 else rintk(d * m_1_pi)
+   return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32)
+
+ # *** approximate sine on small angle. ***
+ def trig_poly(d:UOp, coeff32, coeff64): return d * (polyN(d*d, coeff64) if d.dtype == dtypes.float64 else polyN(d*d, coeff32))
+ # approximate sine on [-pi/2, pi/2]
+ def sin_poly(d:UOp) -> UOp:
+   return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938, 1.0],
+                       [-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10,
+                        -2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815,
+                        -0.166666666666666657414808, 1.0])
+
+ def _ifand(q:UOp, n:int): return (q & n).ne(0)
+
+ def sin_poly_small(d:UOp, q:UOp) -> UOp:
+   r = sin_poly(d)
+   return r * _ifand(q, 1).where(r.const_like(-1), r.const_like(1))
+
+ def sin_poly_large(d:UOp, q:UOp) -> UOp:
+   r = sin_poly(d + _ifand(q, 1).where(d.const_like(math.pi / 2), d.const_like(0)))
+   return r * _ifand(q, 2).where(r.const_like(-1), r.const_like(1))
+
+ # *** toplevel functions for xsin/xlog2/xexp2 ***
+
+ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
+   """
+   Implements a 1.0 ULP approximation for Ops.SIN.
+   - fast=True assumes x <= switch_over.
+   - switch_over is the threshold for switching to payne_hanek_reduction.
+   """
+   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   # mask +-inf/nan as zero
+   x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
+   # x_sign = sign(x)
+   x_sign = x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
+   x_abs = x * x_sign
+   r, q = (cody_waite_reduction if fast else payne_hanek_reduction)(x_abs)
+   if fast: result = sin_poly_small(r, q)
+   else:
+     # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
+     r_small, q_small = cody_waite_reduction(x_abs)
+     result = x_abs.lt(switch_over).where(sin_poly_small(r_small, q_small), sin_poly_large(r, q))
+   # adjusts the sign for abs(x)
+   result = result * x_sign
+   # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
+   return _lazy_map_numbers(d, d.const_like(math.nan), d.const_like(math.nan), d.const_like(math.nan), result)
+
+ def xexp2(d:UOp) -> UOp:
+   """
+   Implements a 1.0 ULP approximation for Ops.EXP2
+   - Paper: https://arxiv.org/pdf/2001.09258
+   """
+   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   # mask +-inf/nan as zero.
+   x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
+   q = rintk(x)
+   # s = d - round(d)
+   s = x - q.cast(x.dtype)
+   # a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2].
+   if d.dtype == dtypes.float64:
+     u = polyN(s, [0.4434359082926529454e-9, 0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4,
+                   0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
+                   0.6931471805599452862e+0, 0.1000000000000000000e+1])
+   else: u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0])
+   u = ldexp2k(u, q) # u*2^q
+   upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype]
+   # Replace x >= upper with +inf
+   u = d.ge(upper).where(d.const_like(math.inf), u)
+   # Replace x <= lower with zero.
+   u = d.lt(lower).where(d.const_like(0.0), u)
+   # exp2(NaN) = NaN
+   return d.ne(d).where(d.const_like(math.nan), u)
+
+ def xlog2(d:UOp) -> UOp:
+   """
+   Implements a 1.0 ULP approximation for Ops.LOG2
+   Paper: https://arxiv.org/pdf/2001.09258 5.5
+   """
+   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+   # TODO: float16 denormals need float32 to achieve precision
+   if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
+   FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
+   is_denormal = d.lt(FLT_MIN)
+   a = is_denormal.where(d * (2 ** 64), d)
+
+   e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
+   m = ldexp3k(a, -e)
+   e = is_denormal.where(e - 64, e)
+
+   x = (m - 1.0) / (m + 1.0)
+   x2 = x * x
+   if d.dtype == dtypes.float64:
+     t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
+                    0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
+     s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
+   else:
+     t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120])
+     s_hi, s_lo = e+x*2.8853900432586669922, x*3.2734474483568488616e-08
+   r = t * (x * x2) + (s_hi + s_lo)
+
+   # log2(Inf) = Inf
+   r = d.ne(math.inf).where(r, r.const_like(math.inf))
+   # log2(x) = NaN for x < 0
+   r = d.lt(-0.0).where(r.const_like(math.nan), r)
+   # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
+   # log2_zero = the value of unmasked xlog2(0.0).
+   log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
+   r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
+   # log2(NaN) = NaN
+   r = d.ne(d).where(r.const_like(math.nan), r)
+   # log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true, so check via the reciprocal.
+   return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))
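
A quick numeric cross-check (not from the package, plain Python floats): evaluating the float32 sine coefficients above over s = d*d with Horner's rule, the evaluation order polyN is assumed to use, reproduces math.sin to well within the claimed 1.0 ULP for a small angle that needs no argument reduction. The horner helper below is a hypothetical stand-in for tinygrad's polyN.

import math
from functools import reduce

def horner(x, coeffs):
  # evaluate c0*x^(n-1) + c1*x^(n-2) + ... + c_{n-1}, highest-order coefficient first
  return reduce(lambda acc, c: acc * x + c, coeffs, 0.0)

# float32 coefficient list from sin_poly above
coeff32 = [2.6083159809786593541503e-06, -0.0001981069071916863322258,
           0.00833307858556509017944336, -0.166666597127914428710938, 1.0]

d = 0.5                      # an angle already inside [-pi/2, pi/2], no reduction needed
approx = d * horner(d * d, coeff32)
assert math.isclose(approx, math.sin(d), rel_tol=1e-6)
print(approx, math.sin(d))   # ~0.47942554 vs 0.479425538604203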