tinygrad 0.9.1-py3-none-any.whl → 0.10.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/codegen/linearize.py (new file)
@@ -0,0 +1,95 @@
+from typing import List, Set, Dict, Tuple
+import functools, heapq
+from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops, GroupOp
+from tinygrad.dtype import dtypes
+from tinygrad.helpers import DEBUG
+
+def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
+  if u in children: return srcs[u]
+  srcs[u] = {}
+  children[u] = []
+  for x in u.src:
+    srcs[u].update(get_children_dfs(x, children, srcs, in_degree))
+    if x.op is Ops.RANGE and x.arg[1]: srcs[u][x] = None
+    children[x].append(u)
+  in_degree[u] = len(u.src)
+  return srcs[u]
+
+def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
+  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
+  # filter nodes that don't link to a sink
+  # BFS toposort
+  children: Dict[UOp, List[UOp]] = {}
+  range_srcs: Dict[UOp, Dict[UOp, None]] = {}
+  in_degree: Dict[UOp, int] = {}
+  get_children_dfs(sink, children, range_srcs, in_degree)
+
+  @functools.lru_cache(None)
+  def get_recursive_children(x:UOp, end:Ops, include_self=False) -> Set[UOp]:
+    if x.op is Ops.SINK: return set()
+    return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
+
+  # scope children impact the toposort and END* insertion
+  scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
+  range_phi = {r:[p for p in scope_children[r] if p.op is Ops.ASSIGN] for r in scope_children if r.op is Ops.RANGE}
+
+  # assign priorities
+  def get_priority(u:UOp):
+    priority = 0
+    # prefer ranges that depend on the least number of independent ranges
+    if u.op is Ops.RANGE and u.arg[1]:
+      priority += u.arg[0]
+      for p in range_phi[u]:
+        priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
+    elif u.op is Ops.CONST:
+      # place consts first here, they don't do anything and it can cause issues with DEFINE_ACC
+      priority -= 100000000000
+    else:
+      # prefer uops that are loop children
+      priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is Ops.RANGE and u in ss])
+    if u.op is Ops.IF and len(u.src) == 1: priority += 10000000 # if penalty
+    return priority
+  priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
+
+  # prevent priority inversion
+  @functools.lru_cache(None)
+  def fix_priority(u:UOp, lowest_priority):
+    if u.op in {Ops.CAST, Ops.BITCAST, *GroupOp.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
+      priorities[u] = min(priorities[u], lowest_priority)
+      if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
+    for x in u.src: fix_priority(x, priorities[u])
+  fix_priority(sink, 0)
+
+  # NOTE: the compare should never make it all the way to u
+  queue:List[Tuple[int, Tuple, UOp]] = []
+  def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
+
+  for u in children:
+    if in_degree[u] == 0: push(u)
+
+  scope_end: Dict[UOp, UOp] = {}
+  _uops: List[UOp] = []
+  while queue:
+    p,_,x = heapq.heappop(queue)
+    if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
+    if x in scope_children: scope_end[x] = x
+    if x.op is Ops.DEFINE_ACC:
+      idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE])
+      _uops.insert(idx, x)
+    else: _uops.append(x)
+    for u, ss in scope_children.items():
+      if x in ss:
+        ss.remove(x)
+        if len(ss) == 0: scope_end[u] = x
+    for u in children[x]:
+      in_degree[u] -= 1
+      if in_degree[u] == 0: push(u)
+
+  # end scopes in toposort order
+  for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], dtypes.void, (u,)))
+
+  # sanity checks (NOTE: these can cause things to be skipped in BEAM)
+  if not skip_check: type_verify(_uops)
+
+  # strip the SINK
+  return _uops[:-1]
tinygrad/codegen/lowerer.py (new file)
@@ -0,0 +1,143 @@
+# the job of the lowerer is to do indexing
+from __future__ import annotations
+import functools, itertools, operator
+from dataclasses import dataclass
+from typing import List, Tuple, cast, Optional
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import variable_to_uop
+from tinygrad.dtype import dtypes, PtrDType
+from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
+from tinygrad.renderer import Renderer
+from tinygrad.helpers import all_int, prod, partition, flatten
+
+# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+  except ValueError: return None
+  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+
+# ***** indexing *****
+
+def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
+  # TODO: symbolic shape
+  if not all_int(dims): return dims
+  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
+    for i,m in enumerate(max_sizes):
+      if dims[i] * dims[i+1] <= m:
+        dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
+        break
+    else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
+  return dims
+
+def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
+  if reverse: dims = dims[::-1]
+  limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
+  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
+  if limited != dims:
+    ret = []
+    # cast for mypy, get_contraction won't be None
+    for idx, contraction in zip(raw_idxs, cast(List[List[int]], get_contraction(dims, limited))):
+      if len(contraction) == 1: ret.append(idx)
+      else:
+        for c in contraction:
+          ret.append(idx % dims[c])
+          idx //= dims[c]
+  return ret[::-1] if reverse else ret
+
+@dataclass
+class IndexContext:
+  idxs: List[UOp]
+  ridxs: List[UOp]
+  acc_num: int = 0
+
+def get_index(ast:UOp, opts:Renderer) -> IndexContext:
+  ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
+  # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
+  full_shape = ast.full_shape
+  first_upcasted = len(full_shape)-ki.upcasted
+  first_output_st: ShapeTracker = ast.src[0].st_arg
+  # if there's no reduce, this is first_upcasted. assumes reduces are at the end
+  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
+  local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
+  # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
+  group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
+    [[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
+    first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
+  global_dims = first_reduce-ki.local_dims
+
+  if opts.has_local:
+    if ki.dont_use_locals:
+      assert ki.local_dims == 0, "can't use locals if there's no local dims"
+      idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
+    else:
+      # define indexes for GPU-like execution
+      idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
+             get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
+  else:
+    # all loops are RANGES
+    idxs = [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
+            for i,g in enumerate(full_shape[:first_reduce])]
+
+  # reduce loops
+  idxs += [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
+    for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
+
+  # upcast loops
+  for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
+    assert isinstance(g, int), "needs to be int to upcast/unroll"
+    idxs.append(UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
+
+  # late indexes (group for reduce)
+  ridxs = idxs[:]
+  for a in range(first_reduce, first_reduce+group_for_reduces):
+    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+
+  return IndexContext(idxs, ridxs)
+
+# ***** lowering (given index) *****
+
+def lower_reduce_axis(ctx: IndexContext, x: UOp):
+  # NOTE: always using ridxs is fine here
+  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
+  assert all(x.op is Ops.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
+  alu_op: Ops = x.arg[0]
+  ret = x.src[0]
+  if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
+    ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
+    ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
+  if not len(reduce_range): return ret
+  # create ACC and assign
+  acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
+  ctx.acc_num += 1
+  return acc.assign(acc.alu(alu_op, ret))
+
+def lower_load_store(ctx: IndexContext, x: UOp):
+  idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
+  # TODO: check has_valid in UPat, not here
+  has_valid = valid.op is not Ops.CONST or valid.arg is not True
+  buf = x.src[0]
+  if x.op is Ops.LOAD:
+    barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
+    return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid if has_valid else None),) + barrier)
+  # NOTE: only store the local reduceop in the threads that are actually doing the reduce
+  if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
+    reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
+    store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
+  else: store_back = False
+  # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
+  if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
+  if (not cast(PtrDType, x.src[0].dtype).local) or store_back:
+    for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
+      if oidx is not ridx: valid = valid * oidx.eq(0)
+    has_valid = valid.op is not Ops.CONST or valid.arg is not True
+  return UOp(Ops.STORE, dtypes.void, (buf.index(idx, valid if has_valid else None), x.src[2]))
+
+pm_lowerer = PatternMatcher([
+  (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
+  (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
+  # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
+])
+
+def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
tinygrad/codegen/transcendental.py (new file)
@@ -0,0 +1,257 @@
+import math
+from typing import Tuple
+from tinygrad.dtype import dtypes, DType
+from tinygrad.helpers import polyN
+from tinygrad.ops import UOp
+
+TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
+
+def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
+  """replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio"""
+  return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
+
+# *** helper functions for bit manipulation ***
+def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1]
+def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d]
+def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d]
+
+# **** utils ****
+def shr(x:UOp, y:int) -> UOp: return x // (2**y)
+def shl(x:UOp, y:int) -> UOp: return x * (2**y)
+
+def rintk(d:UOp) -> UOp:
+  """round d:float to int away from 0"""
+  out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
+  return (d + d.lt(0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)
+
+def pow2if(q:UOp, float_dtype:DType):
+  """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
+  out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype]
+  return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype)
+
+def ilogb2k(d:UOp) -> UOp:
+  """calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
+  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype])
+  # -1 <= ilog2bk(d) <= 128
+  return (shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype)
+
+def ldexp3k(d:UOp, e:UOp) -> UOp:
+  """d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
+  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  cast_map = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}
+  m1 = d.bitcast(cast_map[d.dtype])
+  m2 = shl(e.cast(cast_map[d.dtype]), mantissa_bits(d.dtype))
+  return (m1 + m2).bitcast(d.dtype).cast(d.dtype)
+
+def ldexp2k(d:UOp, e:UOp) -> UOp:
+  """d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
+  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
+  return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
+
+def frexp(v:UOp) -> Tuple[UOp, UOp]:
+  """frexp(v) -> (mantissa, exponent) assuming v != 0"""
+  assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  # m1 = masks for mantissa, m2 = masks to normalize the mantissa.
+  m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
+  m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype]
+  bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype])
+  exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype)
+  # Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
+  mantissa = ((bits & m1) | m2).bitcast(v.dtype)
+  exp = exponent - exponent_bias(v.dtype) + 1
+  return mantissa, exp
+
+# *** reduction algorithms for sine ***
+def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
+  """
+  Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
+    39800.0 <= d <= +Inf
+  Returns a tuple of `(r, q)`:
+  - `r`[d.dtype] is the reminder value corresponding to `round_to_nearest(x % pi/2)`.
+  - `q`[int32] is an integer, and q % 4 is corresponding to the quadrant of the original angle `d`.
+  """
+  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  # https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
+  # 190 bits of 2/pi for Payne-Hanek style argument reduction
+  two_over_pi_f = [0x00000000, 0x28be60db, 0x9391054a, 0x7f09d5f4, 0x7d4d3770, 0x36d8a566, 0x4f10e410]
+
+  intermediate_dtype = dtypes.float32 if d.dtype == dtypes.float16 else d.dtype
+
+  f, e = frexp(d)
+  ia = (f.cast(intermediate_dtype) * 4.294967296e9).cast(dtypes.uint64)
+  # extract 96 relevant bits of 2/pi based on magnitude of argument
+  i = shr(e.cast(dtypes.uint64), 5)
+  e = e.cast(dtypes.int32) & 31
+  offset = 32 - e
+
+  def _take(an:UOp, offset:int, count:int=0) -> UOp:
+    """an = two_over_pi_f[i+offset]"""
+    if count+offset < len(two_over_pi_f) - 1:
+      an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
+    return an
+  def _shl_lazy(x, y): return (x.cast(dtypes.uint64) * pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
+  def _shr_lazy(x, y): return (x.cast(dtypes.uint64) // pow2if(y, d.dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
+
+  a = [_take(UOp.const(dtypes.uint32, 0), i) for i in range(4)]
+  # (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
+  # Note: e >= 1 for all numbers d >= 1.0. assume e != 0
+  hi = _shl_lazy(a[0], e) | _shr_lazy(a[1], offset)
+  mi = _shl_lazy(a[1], e) | _shr_lazy(a[2], offset)
+  lo = _shl_lazy(a[2], e) | _shr_lazy(a[3], offset)
+
+  def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
+  # compute x * 2/pi
+  p = shl(_hp_mul(ia, hi), 32) + _hp_mul(ia, mi) + shr(_hp_mul(ia, lo), 32)
+
+  # round quotient to nearest
+  q = shr(p, 62).cast(dtypes.int32)
+  p = p & 0x3fffffffffffffff
+  r = (p.cast(intermediate_dtype) * (3.4061215800865545e-19)).cast(d.dtype)
+
+  # if fraction >= 0.5, r -= pi/2, q += 1
+  return f.lt(0.5).where(r, r - math.pi/2), f.lt(0.5).where(q, q + 1)
+
+def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
+  """
+  Performs Cody-Waite Reduction: computes the reminder of `d` modulo pi/2 for the values `d` where
+    0 <= abs(d) <= 39800.0
+  Returns a tuple of `(r, q)`, where the output format is the same as that of `payne_hanek_reduction`.
+  """
+  def _reduce_d(x:UOp, q:UOp):
+    # https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefdp.c#L789-L823
+    if x.dtype == dtypes.float64:
+      # https://github.com/shibatch/sleef/blob/f6d8a841fbfddd26ce712834d4da220cd76048fb/src/common/misc.h#L77
+      PI_A, PI_B, PI_C, PI_D = 3.1415926218032836914, 3.1786509424591713469e-08, 1.2246467864107188502e-16, 1.2736634327021899816e-24
+      d = qdh * -PI_A + x
+      d = q * -PI_A + d
+      d = qdh * -PI_B + d
+      d = q * -PI_B + d
+      d = qdh * -PI_C + d
+      d = q * -PI_C + d
+      d = (qdh + q) * -PI_D + d
+    elif x.dtype == dtypes.float16:
+      # [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
+      d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
+    else:
+      # https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefsp.c#L464-L503
+      d = q * -3.1414794921875 + x
+      d = q * -0.00011315941810607910156 + d
+      d = q * -1.9841872589410058936e-09 + d
+      d = q * -1.2154201256553420762e-10 + d
+    return d
+
+  m_1_pi = 0.318309886183790671537767526745028724
+  qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64).cast(d.dtype) * (2.0**24)
+  quadrant = rintk(d * m_1_pi -qdh) if d.dtype == dtypes.float64 else rintk(d * m_1_pi)
+  return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32)
+
+# *** approximate sine on small angle. ***
+def trig_poly(d:UOp, coeff32, coeff64): return d * (polyN(d*d, coeff64) if d.dtype == dtypes.float64 else polyN(d*d, coeff32))
+# approximate sine on [-pi/2, pi/2]
+def sin_poly(d:UOp) -> UOp:
+  return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938, 1.0],
+                      [-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10,
+                       -2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815,
+                       -0.166666666666666657414808, 1.0])
+
+def _ifand(q:UOp, n:int): return (q & n).ne(0)
+
+def sin_poly_small(d:UOp, q:UOp) -> UOp:
+  r = sin_poly(d)
+  return r * _ifand(q, 1).where(r.const_like(-1), r.const_like(1))
+
+def sin_poly_large(d:UOp, q:UOp) -> UOp:
+  r = sin_poly(d + _ifand(q, 1).where(d.const_like(math.pi / 2), d.const_like(0)))
+  return r * _ifand(q, 2).where(r.const_like(-1), r.const_like(1))
+
+# *** toplevel functions for xsin/xlog2/xexp2 ***
+
+def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
+  """
+  Implements a 1.0 ULP approximation for Ops.SIN.
+  - fast=True assumes x <= switch_over.
+  - switch_over is the threshold for switching to payne_hanek_reduction.
+  """
+  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  # mask +-inf/nan as zero
+  x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
+  # x_sign = sign(x)
+  x_sign = x.ne(0).where(x.lt(0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
+  x_abs = x * x_sign
+  r, q = (cody_waite_reduction if fast else payne_hanek_reduction)(x_abs)
+  if fast: result = sin_poly_small(r, q)
+  else:
+    # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
+    r_small, q_small = cody_waite_reduction(x_abs)
+    result = x_abs.lt(switch_over).where(sin_poly_small(r_small, q_small), sin_poly_large(r, q))
+  # adjusts the sign for abs(x)
+  result = result * x_sign
+  # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
+  return _lazy_map_numbers(d, d.const_like(math.nan), d.const_like(math.nan), d.const_like(math.nan), result)
+
+def xexp2(d:UOp) -> UOp:
+  """
+  Implements a 1.0 ULP approximation for Ops.EXP2
+  - Paper: https://arxiv.org/pdf/2001.09258
+  """
+  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  # mask +=inf/nan as zero.
+  x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
+  q = rintk(x)
+  # s = d - round(d)
+  s = x - q.cast(x.dtype)
+  # a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2].
+  if d.dtype == dtypes.float64:
+    u = polyN(s, [0.4434359082926529454e-9, 0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4,
+                  0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
+                  0.6931471805599452862e+0, 0.1000000000000000000e+1])
+  else: u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0])
+  u = ldexp2k(u, q) # u*2^q
+  upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype]
+  # Replace x >= upper with +inf
+  u = d.ge(upper).where(d.const_like(math.inf), u)
+  # Replace x <= lower with zero.
+  u = d.lt(lower).where(d.const_like(0.0), u)
+  # exp2(NaN) = NaN
+  return d.ne(d).where(d.const_like(math.nan), u)
+
+def xlog2(d:UOp) -> UOp:
+  """
+  Implements a 1.0 ULP approximation for Ops.LOG2
+  Paper: https://arxiv.org/pdf/2001.09258 5.5
+  """
+  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  # TODO: float16 denormal need float32 to achieve precision
+  if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
+  FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
+  is_denormal = d.lt(FLT_MIN)
+  a = is_denormal.where(d * (2 ** 64), d)
+
+  e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
+  m = ldexp3k(a, -e)
+  e = is_denormal.where(e - 64, e)
+
+  x = (m - 1.0) / (m + 1.0)
+  x2 = x * x
+  if d.dtype == dtypes.float64:
+    t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
+                   0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
+    s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
+  else:
+    t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120])
+    s_hi, s_lo = e+x*2.8853900432586669922, x*3.2734474483568488616e-08
+  r = t * (x * x2) + (s_hi + s_lo)
+
+  # log2(Inf) = Inf
+  r = d.ne(math.inf).where(r, r.const_like(math.inf))
+  # log2(x) = NaN for x < 0
+  r = d.lt(-0.0).where(r.const_like(math.nan), r)
+  # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
+  # log2_zero = the value of unmasked xlog2(0.0).
+  log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
+  r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
+  # log2(NaN) = NaN
+  r = d.ne(d).where(r.const_like(math.nan), r)
+  # log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
+  return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))
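
All three entry points (xsin, xexp2, xlog2) follow the same recipe: reduce the argument onto a small interval, evaluate a fitted polynomial there, then patch up special values (inf, nan, zero, denormals). For sine, the quadrant integer q returned by payne_hanek_reduction / cody_waite_reduction drives the sign and sin-vs-cos fixup in sin_poly_small / sin_poly_large. A plain-Python sketch of that identity, using a naive double-precision reduction in place of the extended-precision machinery above (illustrative only):

```python
import math

def sin_via_reduction(x: float) -> float:
  # reduce x to r with quadrant q, so x = q*(pi/2) + r and |r| <= pi/4
  q = round(x * (2.0 / math.pi))
  r = x - q * (math.pi / 2.0)
  # same quadrant fixup that sin_poly_small/sin_poly_large encode with q&1 and q&2
  s = math.cos(r) if q & 1 else math.sin(r)
  return -s if q & 2 else s

for x in [0.5, 3.0, 100.0, 12345.678]:
  assert abs(sin_via_reduction(x) - math.sin(x)) < 1e-9, x
print("quadrant reconstruction matches math.sin")
```

The naive q*(pi/2) subtraction loses accuracy as x grows, which is why the kernel-side code uses Payne-Hanek (integer bits of 2/pi) for large arguments and the split-constant Cody-Waite reduction for smaller ones.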