tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,215 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
from typing import List, Tuple, cast, Optional, Any, Dict
|
3
|
+
import functools
|
4
|
+
from tinygrad.shape.shapetracker import ShapeTracker, View
|
5
|
+
from tinygrad.shape.symbolic import sint
|
6
|
+
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
|
7
|
+
from tinygrad.ops import BufferOps, LazyOp, ReduceOps, UnaryOps, MetaOps, KernelInfo, MemBuffer, BinaryOps
|
8
|
+
from tinygrad.codegen.uops import UOp, UOps
|
9
|
+
from tinygrad.renderer import Renderer
|
10
|
+
from tinygrad.helpers import getenv, all_int, get_contraction, prod, partition, flatten
|
11
|
+
|
12
|
+
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
|
13
|
+
from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
|
14
|
+
def variable_to_uop(x, ctx=None) -> UOp:
  """Lower `x` to a UOp: a plain int becomes a pyint CONST, anything else is rendered via `render_ops` with `ctx`."""
  if not isinstance(x, int): return x.render(render_ops, ctx)
  return UOp.const(dtypes.pyint, x)
|
15
|
+
# Dispatch table passed as `ops` to symbolic `.render(...)`: each Node type maps to a function (node, ops, ctx) -> UOp.
render_ops: Any = {
  NumNode: lambda node, ops, ctx: UOp.const(dtypes.pyint, node.b),
  MulNode: lambda node, ops, ctx: node.a.render(ops, ctx) * variable_to_uop(node.b, ctx),
  DivNode: lambda node, ops, ctx: node.a.render(ops, ctx) // variable_to_uop(node.b, ctx),
  ModNode: lambda node, ops, ctx: node.a.render(ops, ctx) % variable_to_uop(node.b, ctx),
  LtNode: lambda node, ops, ctx: node.a.render(ops, ctx).lt(variable_to_uop(node.b, ctx)),
  # a Variable renders to its substitute from ctx when present, otherwise to a DEFINE_VAR bounded by [min, max]
  Variable: lambda node, ops, ctx: UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, node.min), UOp.const(dtypes.int, node.max)), node)
    if ctx is None or node not in ctx else ctx[node],
  # left fold over the terms: n0 + n1 + ... (SumNode) and n0 * n1 * ... (AndNode, bools combined by multiplication)
  SumNode: lambda node, ops, ctx: functools.reduce(lambda acc, term: acc + term.render(ops, ctx), node.nodes[1:], node.nodes[0].render(ops, ctx)),
  AndNode: lambda node, ops, ctx: functools.reduce(lambda acc, term: acc * term.render(ops, ctx), node.nodes[1:], node.nodes[0].render(ops, ctx)),
}
|
24
|
+
|
25
|
+
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
  """Apply one View to per-axis index UOps.

  Returns (offset, valid): offset = view.offset + sum(idx*stride) over axes that can move the address, and valid is
  `vexpr` multiplied by a `lo <= idx` / `idx < hi` check for every mask bound that actually restricts its axis.
  """
  # TODO: dtypes.realint
  offset_expr = variable_to_uop(view.offset)
  per_axis_mask = [None]*len(view.shape) if view.mask is None else view.mask
  for axis_idx, size, stride, bounds in zip(idxs, view.shape, view.strides, per_axis_mask):
    # size-1 and stride-0 axes never contribute to the address
    if size != 1 and stride != 0: offset_expr = offset_expr + axis_idx*variable_to_uop(stride)
    if bounds is None: continue
    # skip bounds that coincide with the full [0, size) range
    if bounds[0] != 0: vexpr = vexpr * axis_idx.ge(variable_to_uop(bounds[0]))
    if bounds[1] != size: vexpr = vexpr * axis_idx.lt(variable_to_uop(bounds[1]))
  return offset_expr, vexpr
|
34
|
+
|
35
|
+
# TODO: change this once UOps is ready to replace symbolic
|
36
|
+
def st_to_uops_graph(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
  """Lower a ShapeTracker to (idx, valid) UOps by composing its views directly as UOp arithmetic.

  The last view consumes `idxs`; each earlier view (minified) gets per-axis indices recovered from the flat index
  of the view after it with // and %. For an ImageDType the flat index becomes a 3-wide VECTORIZE of
  ((idx//4) % width, idx // (4*width), idx % 4) where width = dtype.shape[1].
  """
  idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
  for prev_view in reversed(st.views[:-1]):
    prev_view = prev_view.minify()
    # unflatten idx into prev_view's axes, innermost axis first (running product starts as the int 1)
    stride = 1
    unflat = []
    for size in reversed(prev_view.shape):
      size_uop = variable_to_uop(size)
      unflat.insert(0, (idx//stride) % size_uop)
      stride *= size_uop
    idx, valid = _uop_view(prev_view, unflat, valid)
  if isinstance(dtype, ImageDType):
    width = dtype.shape[1]
    idx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), ((idx // 4) % width, idx // (4 * width), idx % 4))
  return idx, valid
|
49
|
+
|
50
|
+
# TODO: this is the old one, delete when ready
|
51
|
+
def st_to_uops_symbolic(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
  """Legacy lowering: build symbolic (idx, valid) over placeholder Variables, then render them to UOps.

  Placeholder `__idx{i}` (range 0..shape[i]-1) is substituted by idxs[i] while rendering. For an ImageDType the
  index is rendered as a 3-wide VECTORIZE of ((idx//4) % width, idx // (4*width), idx % 4), width = dtype.shape[1].
  """
  placeholders = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
  sym_idx, sym_valid = st.expr_idxs(placeholders)
  ctx = dict(zip(placeholders, idxs))
  uvalid = sym_valid.render(render_ops, ctx)
  if not isinstance(dtype, ImageDType): uidx = sym_idx.render(render_ops, ctx)
  else:
    width = dtype.shape[1]
    parts = ((sym_idx // 4) % width, sym_idx // (4 * width), sym_idx % 4)
    uidx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), tuple(p.render(render_ops, ctx) for p in parts))
  # re-emit any constant valid as a bool CONST (presumably to normalize its dtype)
  if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
  assert uvalid.dtype == dtypes.bool
  return uidx, uvalid
|
64
|
+
|
65
|
+
def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
  """Lower a ShapeTracker to (idx, valid) UOps.

  Uses st_to_uops_graph when UOP_IS_SYMBOLIC is set, else st_to_uops_symbolic. With SYMBOLIC_DIFF set, both lowerings
  are first rendered through the OpenCL renderer and a console diff is printed if the generated code differs.
  """
  if getenv("SYMBOLIC_DIFF"):
    symbolic_pair = st_to_uops_symbolic(st, idxs, dtype)
    graph_pair = st_to_uops_graph(st, idxs, dtype)
    import ocdiff
    from tinygrad.codegen.uopgraph import UOpGraph
    from tinygrad.renderer.cstyle import OpenCLRenderer

    def render(idx_uop, valid_uop):
      # store idx and valid to slots 0 and 1 of a dummy int buffer so both show up in the rendered kernel
      out_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg="idxs")
      stores = tuple(UOp(UOps.STORE, None, (out_buf, UOp.const(dtypes.int, slot), val)) for slot,val in enumerate([idx_uop, valid_uop]))
      return OpenCLRenderer().render("indexing", UOpGraph(UOp(UOps.SINK, None, stores)).linearize(skip_check=True).uops)

    cmp_symbolic, cmp_graph = render(*symbolic_pair), render(*graph_pair)
    if cmp_symbolic != cmp_graph:
      print(ocdiff.console_diff(f"SYMBOLIC {len(cmp_symbolic)}\n"+cmp_symbolic, f"GRAPH {len(cmp_graph)}\n"+cmp_graph))
  # NOTE: the debug path above recomputes the chosen lowering here rather than reusing its result
  lowering = st_to_uops_graph if getenv("UOP_IS_SYMBOLIC") else st_to_uops_symbolic
  return lowering(st, idxs, dtype)
|
81
|
+
|
82
|
+
def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
  """Merge adjacent dims until there are at most len(max_sizes) of them and each fits its positional limit.

  Symbolic (non-int) dims are returned unchanged. Each pass merges the first adjacent pair whose product fits
  max_sizes at that position. Raises RuntimeError when no pair can be merged but the limits are still violated.
  """
  # TODO: symbolic shape
  if not all_int(dims): return dims
  while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
    # only positions with both a limit and a right-hand neighbour are mergeable; bounding the scan keeps
    # dims[i+1] in range so an impossible shape reaches the RuntimeError below instead of an IndexError
    for i,m in enumerate(max_sizes[:len(dims)-1]):
      if dims[i] * dims[i+1] <= m:
        dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
        break
    else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
  return dims
|
92
|
+
|
93
|
+
def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
  """Create one SPECIAL index UOp per (possibly merged) dim, named f"{prefix}{i}".

  When max_sizes forces dims to be merged (via _limit_dims), each merged SPECIAL is split back into per-dim
  indices with % and //, so the result always has one UOp per entry of `dims`. With reverse=True the dims are
  processed back-to-front and the result is reversed back to the original order.
  """
  if reverse: dims = dims[::-1]
  limited = dims if max_sizes is None else _limit_dims(dims, max_sizes)
  specials = [UOp(UOps.SPECIAL, dtypes.pyint, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
  if limited == dims: out = specials
  else:
    out = []
    # cast for mypy, get_contraction won't be None
    groups = cast(List[List[int]], get_contraction(dims, limited))
    for special, group in zip(specials, groups):
      if len(group) == 1:
        out.append(special)
        continue
      remaining = special
      for axis in group:
        out.append(remaining % dims[axis])
        remaining = remaining // dims[axis]
  return out[::-1] if reverse else out
|
107
|
+
|
108
|
+
class IndependentLowerer:
  """Lowers a kernel LazyOp AST into a UOp graph.

  `lower` derives the per-axis index UOps (`self.idxs`, and `self.ridxs` used for reads of local buffers) from the
  kernel's full shape and the renderer's capabilities, then recursively converts the AST with `to_uop`.
  """
  def lower(self, ast:LazyOp, opts:Renderer) -> UOp:
    """Lower `ast` for renderer `opts` and return the root UOp (a SINK when ast.op is MetaOps.KERNEL)."""
    self.output_count = len(ast.src)

    ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
    # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
    full_shape = ast.full_shape
    first_upcasted = len(full_shape)-ki.upcasted
    # if there's no reduce, this is first_upcasted
    # (first axis before the upcasts where the first output's shape differs from full_shape; the appended 0 vs 1
    # sentinel makes .index(True) fall back to first_upcasted. assumes ast.src[0].arg carries a ShapeTracker)
    first_reduce = [x!=y for x,y in zip(ast.src[0].arg.st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True)
    local_loads = [x for x in ast.lazyops if x.op is BufferOps.LOAD and x.arg.idx == -1]
    # NOTE: this is taking the first one...there may be subtleties here with multireduces
    # counts axes in [first_reduce, first_upcasted) where the first local-buffer load differs from the first output
    group_for_reduces = sum([x!=y for x,y in zip(
      local_loads[0].arg.st.shape[first_reduce:first_upcasted], ast.src[0].arg.st.shape[first_reduce:first_upcasted])]) if local_loads else 0
    global_dims = first_reduce-ki.local_dims

    if opts.has_local:
      if ki.dont_use_locals:
        assert ki.local_dims == 0, "can't use locals if there's no local dims"
        self.idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
      else:
        # define indexes for GPU-like execution
        self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
                    get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
    else:
      # all loops are RANGES (the bool in the arg is False for these output-axis loops, True for reduce loops below)
      self.idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
                   for i,g in enumerate(full_shape[:first_reduce])]

    # reduce loops
    self.idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
      for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]

    # upcast loops
    for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
      assert isinstance(g, int), "needs to be int to upcast/unroll"
      self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, tuple(UOp.const(dtypes.pyint, j) for j in range(0, g)), ((i,g),)))

    # late indexes (group for reduce)
    # ridxs is idxs with each group_for_reduce axis replaced by a fresh reduce RANGE (arg id offset by 1000)
    self.ridxs = self.idxs[:]
    for a in range(first_reduce, first_reduce+group_for_reduces):
      self.ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))

    self.uop_cache: Dict[LazyOp, UOp] = {}
    return self.to_uop(ast)

  def to_uop(self, x:LazyOp) -> UOp:
    """Memoized wrapper around _to_uop; the cache is reset at the start of each `lower` call."""
    # NOTE(review): this relies on UOp instances being truthy; `is not None` would state the intent more explicitly
    if uop:=self.uop_cache.get(x, None): return uop
    ret = self._to_uop(x)
    self.uop_cache[x] = ret
    return ret

  def _to_uop(self, x:LazyOp) -> UOp:
    """Convert a single LazyOp (lowering its sources through to_uop) into a UOp."""
    if x.op in BufferOps:
      # local-buffer loads (idx == -1) index with ridxs; image dtypes keep the ImageDType (so st_to_uops emits a
      # 3-wide image index) only for MemBuffers whose idx is not -1, otherwise their .base dtype is used
      idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs,
        x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.arg.dtype)
      # TODO: check has_valid in UPat, not here
      # has_valid is False only when valid is the constant True
      has_valid = valid.op is not UOps.CONST or valid.arg is not True
      if x.op is BufferOps.CONST:
        dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype
        return valid.where(UOp.const(dtype, x.arg.val), UOp.const(dtype, 0))
      if x.arg.idx < 0:
        # negative buffer index -> DEFINE_LOCAL named temp{-idx}, sized by the ShapeTracker's real_size()
        buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype),
                  arg=(f"temp{-x.arg.idx}", x.arg.st.real_size()))
      else:
        buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (), x.arg.idx)
      if x.op is BufferOps.LOAD:
        # a LOAD with a source gets a BARRIER that depends on that source's lowering
        barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
        load_dtype = x.arg.dtype.scalar()
        if idx.dtype == dtypes.int.vec(3):
          # this should all simplify if there's consts for id4. if not, w/e
          # image load: fetch a vec4 at the first two index components, then select lane id4 (NaN if no lane matches)
          idx, id4 = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx.src[0], idx.src[1])), idx.src[2]
          vec_load = UOp(UOps.LOAD, load_dtype.vec(4), (buf, idx) + ((UOp.const(load_dtype.vec(4), 0), valid) if has_valid else ()) + barrier)
          return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load_dtype, (vec_load,), i)),
                                  range(4), UOp.const(load_dtype, float('nan')))
        # gated loads carry (alt value 0, valid) in their sources
        return UOp(UOps.LOAD, load_dtype, (buf, idx) + ((UOp.const(load_dtype, 0), valid) if has_valid else ()) + barrier)
      # NOTE: only store the local reduceop in the first thread (this is wrong for non group for reduces!)
      # for non-local stores, gate on idx == 0 along every axis where ridxs differs from idxs
      if x.arg.idx >= 0:
        for oidx, ridx in zip(self.idxs, self.ridxs):
          if oidx != ridx: valid = valid * oidx.eq(0)
        has_valid = valid.op is not UOps.CONST or valid.arg is not True
      return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))

    in_uops = tuple(self.to_uop(y) for y in x.src)
    if x.op is MetaOps.KERNEL: return UOp(UOps.SINK, src=in_uops)
    if x.op is UnaryOps.CAST: return UOp(UOps.CAST, x.arg.scalar(), in_uops)
    if x.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, x.arg.scalar(), in_uops)
    if x.op in ReduceOps:
      dtype = x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype
      if x.op is ReduceOps.WMMA:
        # CONTRACT both inputs over their upcast axes, run one vector WMMA, then EXPAND the result lanes back
        upcast_axes = x.arg[-2]
        wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
        ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=(
          UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
          UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
          UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
        return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axes[2])
      # NOTE: always using ridxs is fine here
      # reduced axes split into RANGE loops (reduced by a REDUCE uop) and EXPAND axes (reduced in-register below)
      reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg], lambda y: y.op is UOps.RANGE)
      alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, x.op)]
      ret = in_uops[0]
      if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
        ret = UOp(UOps.CONTRACT, dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
        ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(cast(DType, ret.dtype).count)])
      return UOp(UOps.REDUCE, dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
    # everything else is a plain ALU op over the lowered sources
    return in_uops[0].alu(x.op, *in_uops[1:])
|
214
|
+
|
215
|
+
def lazyop_to_uop(ast:LazyOp, opts:Renderer) -> UOp:
  """Lower a kernel LazyOp AST to a UOp graph using a fresh IndependentLowerer for renderer `opts`."""
  lowerer = IndependentLowerer()
  return lowerer.lower(ast, opts)
|
@@ -0,0 +1,310 @@
|
|
1
|
+
import math, functools
|
2
|
+
from typing import Tuple, List
|
3
|
+
from tinygrad.dtype import dtypes, DType
|
4
|
+
from tinygrad.codegen.uops import UOp
|
5
|
+
|
6
|
+
# Float dtypes accepted by the approximations in this file (several functions below assert membership).
TRANSCENDENTAL_SUPPORTED_DTYPES = {dtypes.float16, dtypes.float32, dtypes.float64}
|
7
|
+
|
8
|
+
def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
  """Elementwise select on x: +inf -> `inf`, -inf -> `_inf`, NaN -> `nan`, any other value -> `ratio`.

  NaN is detected with x != x, which holds only for NaN.
  """
  finite_or_neg_inf = x.ne(-math.inf).where(ratio, _inf)
  not_pos_inf = x.ne(x).where(nan, finite_or_neg_inf)
  return x.ne(math.inf).where(not_pos_inf, inf)
|
11
|
+
# *** helper functions for double/quad precision arithmetic ***
|
12
|
+
def dfadd2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]:
  """Add two (hi, lo) pairs component-wise: (xx + yx, xy + yy). No renormalization of the pair is done."""
  hi = xx + yx
  lo = xy + yy
  return hi, lo
|
13
|
+
def dfmul2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]:
  """Multiply two (hi, lo) pairs: hi = xx*yx, lo = xx*yy + xy*yx (the lo*lo term is dropped)."""
  hi = xx * yx
  lo = xx * yy + xy * yx
  return hi, lo
|
14
|
+
def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]:
  """Divide pair (nx, ny) by pair (dx, dy) using one reciprocal of dx.

  hi = nx / dx; lo = (ny - hi*dy) / dx, with both divisions done as multiplication by dx.recip().
  """
  inv_dx = dx.recip()
  hi = nx * inv_dx
  return hi, (ny - hi * dy) * inv_dx
|
19
|
+
# *** helper functions for bit manipulation ***
|
20
|
+
def significand_bits(d:DType) -> int:
  """Number of explicitly stored mantissa bits of a float dtype (KeyError for unsupported dtypes)."""
  table = {dtypes.float64: 52, dtypes.float32: 23, dtypes.float16: 10}
  return table[d]

def exponent_bias(d:DType) -> int:
  """One less than the IEEE exponent bias of a float dtype; callers in this file add +1 back where needed."""
  table = {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 14}
  return table[d]

def exponent_mask(d:DType) -> int:
  """All-ones mask covering the exponent field width of a float dtype (after shifting out the mantissa)."""
  table = {dtypes.float64: 0x7FF, dtypes.float32: 0xFF, dtypes.float16: 0x1F}
  return table[d]
|
23
|
+
|
24
|
+
def float_to_bits(d:UOp) -> UOp:
  """Bitcast a float UOp to the unsigned integer dtype of the same width."""
  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  same_width_uint = {dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}
  return d.bitcast(same_width_uint[d.dtype])
|
28
|
+
|
29
|
+
def bits_to_float(d:UOp, float_dtype:DType) -> UOp:
  """Bitcast an unsigned int UOp back to a float of the same width.

  uint64 -> float64 and uint32 -> float32 are fixed; uint16 bitcasts to `float_dtype`, which is ignored otherwise.
  """
  assert d.dtype in [dtypes.uint64, dtypes.uint32, dtypes.uint16]
  target = {dtypes.uint64: dtypes.float64, dtypes.uint32: dtypes.float32, dtypes.uint16: float_dtype}
  return d.bitcast(target[d.dtype])
|
33
|
+
# **** utils ****
|
34
|
+
def shr(x:UOp, y:int) -> UOp:
  """Right shift by y bits, emulated as floor division by 2**y."""
  divisor = 2**y
  return x // divisor

def shl(x:UOp, y:int) -> UOp:
  """Left shift by y bits, emulated as multiplication by 2**y."""
  factor = 2**y
  return x * factor
|
36
|
+
|
37
|
+
def rintk(d:UOp) -> UOp:
  """round(d:float) -> int: round to nearest, ties away from zero.

  Adds -0.5 (for d < 0) or +0.5 (otherwise) and casts to the same-width signed int; this relies on the
  float->int cast truncating toward zero. (The previous docstring said "ceiling", which this is not: 1.2 -> 1.)
  """
  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  return_t = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
  return (d + d.lt(0.0).where(d.const(-0.5), d.const(0.5))).cast(return_t)
|
42
|
+
|
43
|
+
def pow2if(q:UOp, float_dtype:DType):
  """Build 2**q as a float by writing the biased exponent q + bias directly into the exponent field.

  q must be an integer UOp whose biased exponent stays in range (e.g. [-126, 127] for float32). int16 input
  produces `float_dtype`; int64 -> float64; int32/uint32 -> float32.
  """
  assert q.dtype in (dtypes.int64, dtypes.int32, dtypes.int16, dtypes.uint32)
  target = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype, dtypes.uint32: dtypes.float32}[q.dtype]
  biased_exp = q + (exponent_bias(target)+1)
  return shl(biased_exp, significand_bits(target)).bitcast(target)
|
48
|
+
|
49
|
+
def ilogb2k(d:UOp) -> UOp:
  """Integer part of log2(d) read from the exponent field, for normalized d in [0, +inf).

  Computes ((bits(d) >> significand_bits) & exponent_mask) - IEEE bias, in the same-width signed int dtype.
  """
  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  signed_t = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
  dint = d.bitcast(signed_t)
  raw_exponent = shr(dint, significand_bits(d.dtype)) & exponent_mask(d.dtype)
  # exponent_bias() is one below the IEEE bias, hence the +1
  return raw_exponent - (exponent_bias(d.dtype)+1)
|
56
|
+
|
57
|
+
def ldexp3k(d:UOp, e:UOp) -> UOp:
  """d * 2**e by adding e directly to d's exponent field.

  e is a float holding an integer (cast to the same-width signed int here); the result is only meaningful while the
  adjusted exponent stays in range.
  """
  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  out_dtype = d.dtype
  signed_t = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
  e_int = e.cast(signed_t)
  d_bits = d.bitcast(signed_t)
  exponent_delta = shl(e_int, significand_bits(d.dtype))
  # the trailing cast targets the same dtype as the bitcast; kept as-is
  return (d_bits + exponent_delta).bitcast(d.dtype).cast(out_dtype)
|
66
|
+
|
67
|
+
def ldexp2k(d:UOp, e:UOp) -> UOp:
  """d * 2**e as two multiplications by 2**(e>>1) and 2**(e - (e>>1)), which keeps each factor in range.

  Faster than ldexp3k but only valid for d > 0 that is not denormal; e must be a signed integer UOp.
  """
  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
  first_factor = pow2if(shr(e, 1), d.dtype)
  second_factor = pow2if(e - shr(e, 1), d.dtype)
  return (d * first_factor) * second_factor
|
71
|
+
|
72
|
+
def frexp(v:UOp) -> Tuple[UOp, UOp]:
  """Split v into (mantissa, exponent) via bit manipulation.

  For a nonzero exponent field the mantissa keeps v's sign and fraction bits with a fixed exponent field:
  0.5-binade for float32/float64 (mask 0x3F000000 / 0x3FE0000000000000), 1.0-binade for float16 (0x3C00).
  The exponent is the stored field minus 126/1022 (float32/float64) or 15 (float16). When the exponent field is
  zero, v itself and exponent 0 are returned. float16 exponents are bitcast to int16.
  """
  assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  mantissa_keep = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
  exponent_set = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3C00}[v.dtype]
  bias = {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 15}[v.dtype]
  bits = float_to_bits(v)
  exponent = shr(bits, significand_bits(v.dtype)) & exponent_mask(v.dtype)
  # True where the stored exponent field is nonzero
  has_exponent = exponent.ne(0.0)
  normalized = bits_to_float((bits & mantissa_keep) | exponent_set, v.dtype)
  mantissa = has_exponent.where(normalized, v)
  exp = exponent + (-bias)
  exp = has_exponent.where(exp, exp.const(0))
  if v.dtype == dtypes.float16: exp = exp.bitcast(dtypes.int16)
  return mantissa, exp
|
88
|
+
|
89
|
+
def mla(x:UOp, y:UOp, z:UOp) -> UOp:
  """Multiply-add x*y + z, written as a separate multiply and add."""
  product = x * y
  return product + z
|
90
|
+
|
91
|
+
def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp:
  """Horner evaluation: starting from u, apply u = u*s + c for each c in coeffs (in order) and return u."""
  for c in coeffs: u = mla(u, s, u.const(c))
  return u
|
92
|
+
# *** reduction algorithms for sine ***
|
93
|
+
def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
  """
  Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
    39800.0 <= d <= +Inf
  Returns a tuple of `(r, q)`:
  - `r`[d.dtype] is the remainder value corresponding to `round_to_nearest(x % pi/2)`.
    ensuring that `r` is in the range of [0, pi/2).
  - `q`[int32] is an integer taking values 0,1,2 or 3, corresponding to the quadrant of the original angle `d`.
  The product of d's mantissa with bits of 2/pi (a 32-bit-word table) is accumulated in uint64 integer arithmetic.
  """
  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  # bits of 2/pi, 32 per word
  two_over_pi_f = [0x00000000,0x28be60db,0x9391054a,0x7f09d5f4,0x7d4d3770,0x36d8a566,0x4f10e410]

  input_dtype: DType = d.dtype
  # float16 inputs are widened to float32 for the float<->int conversions below
  dtype_via = dtypes.float32 if d.dtype == dtypes.float16 else d.dtype
  acc_dtype = dtypes.uint64

  f, e = frexp(d)
  # mantissa scaled by 2**32 (4.294967296e9) as an integer
  ia = (f.cast(dtype_via) * 4.294967296e9).cast(dtypes.uint64)
  # i = which 32-bit table word to start at, e = bit offset inside that word
  i = shr(e.cast(dtypes.uint64), 5)
  e = (e.cast(dtypes.uint64) & 31).cast(dtypes.uint32)
  offset = -e + 32

  # NOTE: despite its name, _eq returns the inequality mask arr != eq_to
  def _eq(arr:UOp, eq_to:int) -> UOp: return arr.ne(eq_to)
  def _take(an:UOp, offset:int, count:int=0) -> UOp:
    """an = two_over_pi_f[i+offset]"""
    # unrolled as a chain of selects over every possible value of i, since i is a runtime UOp
    if count+offset <= len(two_over_pi_f[0:-2]):
      an = _eq(i, count).where(_take(an, offset, count=count+1), an.const(two_over_pi_f[count+offset]))
    return an
  # shifts by a runtime amount are emulated by multiplying/dividing by 2**y in uint64, then truncating to uint32
  def _exact_pow2if(x): return pow2if(x, input_dtype).cast(acc_dtype)
  def _shl_lazy(x, y): return (x.cast(acc_dtype) * _exact_pow2if(y)).cast(dtypes.uint32)
  def _shr_lazy(x, y): return (x.cast(acc_dtype) // _exact_pow2if(y)).cast(dtypes.uint32)
  # a_n = (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
  a1 = _take(i.const(0).cast(dtypes.uint32), 0)
  a2 = _take(i.const(0).cast(dtypes.uint32), 1)
  a3 = _take(i.const(0).cast(dtypes.uint32), 2)
  a4 = _take(i.const(0).cast(dtypes.uint32), 3)
  # Note: e >= 1 for all numbers d >= 1.0. assume e != 0
  hi = _shl_lazy(a1, e) | _shr_lazy(a2, offset)
  mi = _shl_lazy(a2, e) | _shr_lazy(a3, offset)
  lo = _shl_lazy(a3, e) | _shr_lazy(a4, offset)

  # widening 32x32 -> 64-bit multiply
  def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
  p = _hp_mul(ia, lo)
  p = _hp_mul(ia, mi) + shr(p, 32)
  p = shl(_hp_mul(ia, hi), 32) + p

  # the top two bits of p give the quadrant, the remaining 62 bits the fraction
  q = shr(p, 62).cast(dtypes.int32)
  p = p & 0x3fffffffffffffff
  # 3.4061215800865545e-19 ~= (pi/2) / 2**62: converts the 62-bit fraction into radians
  r = (p.cast(dtype_via) * (3.4061215800865545e-19)).cast(input_dtype)

  # if fraction >= 0.5, r -= pi/2, q += 1
  # NOTE(review): `f` is frexp's mantissa, which lies in [0.5, 1) for normal float32/float64 and [1, 2) for float16,
  # so f.lt(0.5) appears to be False for every nonzero normal input and the adjustment is always applied — confirm intent.
  return f.lt(0.5).where(r, r + r.const(-math.pi / 2)), f.lt(0.5).where(q, q + 1)
|
145
|
+
|
146
|
+
def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
  """
  Performs Cody-Waite Reduction for the values `d` where
    0 <= abs(d) <= 39800.0
  Computes q = round(d / pi) (m_1_pi is 1/pi) and r = d - q*pi, subtracting pi as a sum of split constants so the
  subtraction loses little precision. Returns `(r, q)` with q cast to int32.
  NOTE: unlike `payne_hanek_reduction`, q counts multiples of pi (not pi/2); sin_poly_small only uses its low bit.
  """
  m_1_pi = 0.318309886183790671537767526745028724
  # float64 only: high part of d/pi, truncated to a multiple of 2**24 (16777216), reduced separately for precision
  qdh = (d * (m_1_pi / 16777216)).cast(dtypes.int64).cast(d.dtype) * 16777216.0
  def _quadrant(x:UOp) -> UOp:
    # NOTE: the float64 branch reads the enclosing `d`; it is only called as _quadrant(d), so x is d
    if x.dtype == dtypes.float64: return rintk(mla(d, d.const(m_1_pi), -qdh)).cast(x.dtype)
    return rintk(x * m_1_pi).cast(x.dtype)
  def _reduce_d(x:UOp, q:UOp):
    if x.dtype == dtypes.float64:
      # subtract (qdh + q) * pi with pi split into four parts
      d = mla(qdh, x.const(-3.1415926218032836914), x)
      d = mla(q, x.const(-3.1415926218032836914), d)
      d = mla(qdh, x.const(-3.1786509424591713469e-08), d)
      d = mla(q, x.const(-3.1786509424591713469e-08), d)
      d = mla(qdh, x.const(-1.2246467864107188502e-16), d)
      d = mla(q, x.const(-1.2246467864107188502e-16), d)
      d = mla(qdh + q, x.const(-1.2736634327021899816e-24), d)
    elif x.dtype == dtypes.float16:
      # [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
      d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
    else:
      # float32: subtract q * pi with pi split into four parts
      d = mla(q, x.const(-3.1414794921875), x)
      d = mla(q, x.const(-0.00011315941810607910156), d)
      d = mla(q, x.const(-1.9841872589410058936e-09), d)
      d = mla(q, x.const(-1.2154201256553420762e-10), d)
    return d
  return _reduce_d(d, (q := _quadrant(d))), q.cast(dtypes.int32)
|
176
|
+
# *** approximate sine on small angle. ***
|
177
|
+
def trig_poly(d:UOp, coeff32, coeff64):
  """Odd polynomial d + d*s*P(s), with s = d*d.

  float64 evaluates P with 9 coefficients (an Estrin-style split on s, s^2, s^4 for the first 8, then one Horner
  step with the last); other dtypes use Horner over coeff32.
  """
  s = d * d
  if d.dtype == dtypes.float64:
    s2 = s * s
    s4 = s2 * s2
    def poly4(x:UOp, x2:UOp, c3, c2, c1, c0) -> UOp: return mla(x2, mla(x, x.const(c3), x.const(c2)), mla(x, x.const(c1), x.const(c0)))
    def poly8(x, x2, x4, c7, c6, c5, c4, c3, c2, c1, c0) -> UOp: return mla(x4, poly4(x, x2, c7, c6, c5, c4), poly4(x, x2, c3, c2, c1, c0))
    u = mla(poly8(s, s2, s4, *coeff64[:-1]), s, d.const(coeff64[-1]))
  else:
    u = polyN(s.const(coeff32[0]), s, coeff32[1:])
  return mla(s, u * d, d)
|
190
|
+
# approximate sine on [-pi/2, pi/2]
|
191
|
+
def sin_poly(d:UOp) -> UOp:
  """Polynomial approximation of sin(d) for d in [-pi/2, pi/2], via trig_poly."""
  f32_coeffs = [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938]
  f64_coeffs = [-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10,
                -2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815,
                -0.166666666666666657414808]
  return trig_poly(d, f32_coeffs, f64_coeffs)
|
196
|
+
|
197
|
+
def sin_poly_small(d:UOp, q:UOp) -> UOp:
  """sin_poly(d), negated when bit 0 of q is set (q from cody_waite_reduction counts multiples of pi)."""
  r = sin_poly(d)
  odd = (q & 1).ne(0)
  return r * odd.where(r.const(-1), r.const(1))
|
201
|
+
|
202
|
+
def sin_poly_large(d:UOp, q:UOp) -> UOp:
  """sin for a quadrant-reduced angle: add pi/2 to d when bit 0 of q is set, negate the result when bit 1 is set."""
  def bit_set(n:int): return (q & n).ne(0)
  shifted = d + bit_set(1).where(d.const(math.pi / 2), d.const(0))
  r = sin_poly(shifted)
  return r * bit_set(2).where(r.const(-1), r.const(1))
|
207
|
+
|
208
|
+
# *** toplevel functions for xsin/xlog2/xexp2 ***
|
209
|
+
|
210
|
+
def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
  """
  Implements a 1.0 ULP approximation for UnaryOps.SIN.
  - fast=True assumes x <= switch_over.
  - switch_over is the threshold for switching to payne_hanek_reduction.
  Works on abs(x) and reapplies the sign at the end; +-inf and NaN inputs yield NaN.
  """
  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  reduction_algo = cody_waite_reduction if fast else payne_hanek_reduction
  # mask +-inf/nan as zero
  x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d)
  # x_sign = sign(x)
  x_sign = x.ne(0).where(x.lt(0).where(x.const(-1), x.const(1)), x.const(0))
  x_abs = x * x_sign
  r, q = reduction_algo(x_abs)
  if fast: result = sin_poly_small(r, q)
  else:
    # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
    # both reductions are built and a per-element select picks one; q has a different meaning in each
    # (multiples of pi vs quadrants), hence the matching sin_poly_small / sin_poly_large select
    switch_over_map = x_abs.lt(switch_over)
    r_fast, q_fast = cody_waite_reduction(x_abs)
    r = switch_over_map.where(r_fast, r)
    q = switch_over_map.where(q_fast, q)
    result = switch_over_map.where(sin_poly_small(r, q), sin_poly_large(r, q))
  result = result * x_sign  # adjusts the sign for abs(x).
  # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
  return _lazy_map_numbers(d, d.const(math.nan), d.const(math.nan), d.const(math.nan), result)
|
235
|
+
|
236
|
+
def xexp2(d:UOp) -> UOp:
  """
  Implements a 1.0 ULP approximation for UnaryOps.EXP2
  - Paper: https://arxiv.org/pdf/2001.09258
  Splits x = q + s with q = round(x), approximates 2**s with a polynomial, and scales by 2**q.
  """
  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  fp64_p = d.dtype == dtypes.float64
  # mask +-inf/nan as zero.
  x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d)
  q = rintk(x)
  # s = x - round(x), so s lies in [-0.5, 0.5]
  s = x - q.cast(x.dtype)
  # polynomial approximation of 2**s in s (12 coefficients for float64, 7 otherwise), evaluated with Horner's rule
  if fp64_p:
    u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5,
                                                     0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2,
                                                     0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
                                                     0.6931471805599452862e+0, 0.1000000000000000000e+1])
  else:
    u = polyN(s.const(0.1535920892e-3), s, [0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1])
  u = ldexp2k(u, q) # u*2^q
  # overflow / underflow thresholds on x (NOTE(review): float16's upper is the float 23.0 while the others are ints)
  upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[x.dtype]
  lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[x.dtype]
  # Replace x >= upper with +inf
  # (the ne() line alone handles x == upper; the lt() line also covers it, so the first looks redundant)
  u = x.ne(upper).where(u, x.const(math.inf))
  u = x.lt(upper).where(u, x.const(math.inf))
  # Replace x < lower with zero.
  u = x.lt(lower).where(x.const(0.0), u)
  # x=NaN never satisfies x < Inf. (for fastmode)
  u = x.lt(math.inf).where(u, u.const(math.nan))
  # exp2(Inf) = Inf, exp2(-Inf) = 0, exp2(NaN) = NaN
  return _lazy_map_numbers(d, d.const(math.inf), d.const(0.0), d.const(math.nan), u)
|
268
|
+
|
269
|
+
def xlog2(d:UOp) -> UOp:
  """
  Implements a 1.0 ULP approximation for UnaryOps.LOG2
  Paper: https://arxiv.org/pdf/2001.09258
  Writes d = m * 2**e, then log2(d) = e + log2(m), with log2(m) computed from x = (m-1)/(m+1).
  """
  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  fp64_p = d.dtype == dtypes.float64
  # not the IEEE FLT_MIN: inputs below this threshold get pre-scaled by 2**64 before exponent extraction
  FLT_MIN = d.const(1e-6 if d.dtype == dtypes.float16 else 1e-4)
  d_orig = d
  denormal_map = d.lt(FLT_MIN)
  # 8 multiplications by 2**8 = scale by 2**64 (compensated by e - 64 below)
  for _ in range(8): d = denormal_map.where(d * (2 ** 8), d)

  # e = exponent of d/0.75, so m = d * 2**-e lands near 1
  e = ilogb2k(d * (1.0 / 0.75)).cast(d.dtype)
  m = ldexp3k(d, -e)
  e = denormal_map.where(e + (-64), e)

  if fp64_p:
    x = (m - 1.0) * (m + 1.0).recip()
    x2 = x * x
    t = polyN(x.const(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
                                                      0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
    # 2.885390081777926774 ~= 2/ln(2)
    s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(t.const(2.885390081777926774), t.const(0), x, x.const(0)))
    r = mla(t, x * x2, s_hi + s_lo)
  else:
    # x = (m-1)/(m+1) computed as a (hi, lo) pair
    xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const(-1), m.const(0), m, m.const(0)), *dfadd2_f2_f2_f2(m.const(1), m.const(0), m, m.const(0)))
    x2 = xx * xx
    t = polyN(d.const(0.4374550283e+0), x2, [0.5764790177e+0, 0.9618012905120])
    # 2/ln(2) split into hi + lo constants
    sx, sy = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(xx, xy, xx.const(2.8853900432586669922), xy.const(3.2734474483568488616e-08)))
    sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const(0), (x2 * xx) * t)
    r = sx + sy
  # log2(Inf) = Inf
  r = d_orig.ne(math.inf).where(r, r.const(math.inf))
  # log2(x=-0.01) = NaN. where x < 0
  r = d_orig.lt(-0.0).where(r.const(math.nan), r)
  # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
  # log2_zero = the value of unmasked xlog2(0.0).
  log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79, None: -math.inf}[d.dtype]
  r = r.ne(log2_zero).where(r, r.const(-math.inf))
  # log(NaN) = NaN, using for all real number x, either of x < Inf, x == Inf becomes True.
  r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const(math.nan), d))
  # log(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
  return d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))
|