tinygrad 0.9.1-py3-none-any.whl → 0.9.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. tinygrad/codegen/kernel.py +248 -115
  2. tinygrad/codegen/lowerer.py +215 -0
  3. tinygrad/codegen/transcendental.py +310 -0
  4. tinygrad/codegen/uopgraph.py +622 -0
  5. tinygrad/codegen/uops.py +235 -393
  6. tinygrad/device.py +428 -69
  7. tinygrad/dtype.py +18 -4
  8. tinygrad/engine/graph.py +19 -32
  9. tinygrad/engine/jit.py +148 -70
  10. tinygrad/engine/realize.py +127 -51
  11. tinygrad/engine/schedule.py +259 -216
  12. tinygrad/engine/search.py +29 -22
  13. tinygrad/function.py +9 -0
  14. tinygrad/helpers.py +87 -49
  15. tinygrad/lazy.py +34 -35
  16. tinygrad/multi.py +41 -36
  17. tinygrad/nn/__init__.py +39 -22
  18. tinygrad/nn/state.py +3 -3
  19. tinygrad/ops.py +63 -62
  20. tinygrad/renderer/__init__.py +43 -21
  21. tinygrad/renderer/assembly.py +104 -106
  22. tinygrad/renderer/cstyle.py +87 -60
  23. tinygrad/renderer/llvmir.py +21 -30
  24. tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
  25. tinygrad/runtime/autogen/cuda.py +6 -162
  26. tinygrad/runtime/autogen/kfd.py +32 -0
  27. tinygrad/runtime/autogen/libc.py +4260 -0
  28. tinygrad/runtime/autogen/nvrtc.py +579 -0
  29. tinygrad/runtime/graph/clang.py +2 -2
  30. tinygrad/runtime/graph/cuda.py +8 -11
  31. tinygrad/runtime/graph/hcq.py +120 -107
  32. tinygrad/runtime/graph/metal.py +18 -15
  33. tinygrad/runtime/ops_amd.py +197 -305
  34. tinygrad/runtime/ops_clang.py +2 -2
  35. tinygrad/runtime/ops_cuda.py +36 -94
  36. tinygrad/runtime/ops_disk.py +3 -7
  37. tinygrad/runtime/ops_gpu.py +4 -2
  38. tinygrad/runtime/ops_hip.py +70 -0
  39. tinygrad/runtime/ops_metal.py +38 -27
  40. tinygrad/runtime/ops_nv.py +283 -363
  41. tinygrad/runtime/ops_python.py +26 -30
  42. tinygrad/runtime/support/compiler_cuda.py +78 -0
  43. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
  44. tinygrad/runtime/support/elf.py +38 -0
  45. tinygrad/shape/shapetracker.py +5 -14
  46. tinygrad/shape/symbolic.py +4 -8
  47. tinygrad/shape/view.py +34 -22
  48. tinygrad/tensor.py +399 -97
  49. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
  50. tinygrad-0.9.2.dist-info/RECORD +70 -0
  51. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
  52. tinygrad/codegen/linearizer.py +0 -528
  53. tinygrad-0.9.1.dist-info/RECORD +0 -63
  54. /tinygrad/runtime/{driver → support}/__init__.py +0 -0
  55. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
  56. {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/codegen/lowerer.py
@@ -0,0 +1,215 @@
1
+ from __future__ import annotations
2
+ from typing import List, Tuple, cast, Optional, Any, Dict
3
+ import functools
4
+ from tinygrad.shape.shapetracker import ShapeTracker, View
5
+ from tinygrad.shape.symbolic import sint
6
+ from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
7
+ from tinygrad.ops import BufferOps, LazyOp, ReduceOps, UnaryOps, MetaOps, KernelInfo, MemBuffer, BinaryOps
8
+ from tinygrad.codegen.uops import UOp, UOps
9
+ from tinygrad.renderer import Renderer
10
+ from tinygrad.helpers import getenv, all_int, get_contraction, prod, partition, flatten
11
+
12
+ # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
13
+ from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
14
+ def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x.render(render_ops, ctx)
15
+ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.pyint, self.b),
16
+ MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
17
+ DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
18
+ ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
19
+ LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
20
+ Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else \
21
+ UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max)), self),
22
+ SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
23
+ AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
24
+
25
+ def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
26
+ # TODO: dtypes.realint
27
+ iexpr = variable_to_uop(view.offset)
28
+ for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
29
+ if sh != 1 and st != 0: iexpr = iexpr + idx*variable_to_uop(st)
30
+ if m is not None:
31
+ if m[0] != 0: vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
32
+ if m[1] != sh: vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
33
+ return iexpr, vexpr
34
+
35
+ # TODO: change this once UOps is ready to replace symbolic
36
+ def st_to_uops_graph(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
37
+ idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
38
+ for view in reversed(st.views[0:-1]):
39
+ view = view.minify()
40
+ acc, idxs = 1, []
41
+ for _d in reversed(view.shape):
42
+ d = variable_to_uop(_d)
43
+ idxs.append((idx//acc)%d)
44
+ acc *= d
45
+ idx, valid = _uop_view(view, idxs[::-1], valid)
46
+ if isinstance(dtype, ImageDType):
47
+ idx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), ((idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4))
48
+ return idx, valid
49
+
50
+ # TODO: this is the old one, delete when ready
51
+ def st_to_uops_symbolic(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
52
+ fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
53
+ idx, valid = st.expr_idxs(fake_idxs)
54
+ ctx = dict(zip(fake_idxs, idxs))
55
+ uvalid = valid.render(render_ops, ctx)
56
+ if isinstance(dtype, ImageDType):
57
+ image_idxs = (idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4
58
+ uidx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), tuple(x.render(render_ops, ctx) for x in image_idxs))
59
+ else:
60
+ uidx = idx.render(render_ops, ctx)
61
+ if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
62
+ assert uvalid.dtype == dtypes.bool
63
+ return uidx, uvalid
64
+
65
+ def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
66
+ if getenv("SYMBOLIC_DIFF"):
67
+ symbolic_idx, symbolic_valid = st_to_uops_symbolic(st, idxs, dtype)
68
+ graph_idx, graph_valid = st_to_uops_graph(st, idxs, dtype)
69
+ import ocdiff
70
+ from tinygrad.codegen.uopgraph import UOpGraph
71
+ from tinygrad.renderer.cstyle import OpenCLRenderer
72
+
73
+ def render(s1, s2):
74
+ glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg="idxs")
75
+ st = tuple(UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, i), s)) for i,s in enumerate([s1,s2]))
76
+ return OpenCLRenderer().render("indexing", UOpGraph(UOp(UOps.SINK, None, st)).linearize(skip_check=True).uops)
77
+
78
+ cmp_symbolic, cmp_graph = render(symbolic_idx, symbolic_valid), render(graph_idx, graph_valid)
79
+ if cmp_symbolic != cmp_graph: print(ocdiff.console_diff(f"SYMBOLIC {len(cmp_symbolic)}\n"+cmp_symbolic, f"GRAPH {len(cmp_graph)}\n"+cmp_graph))
80
+ return st_to_uops_graph(st, idxs, dtype) if getenv("UOP_IS_SYMBOLIC") else st_to_uops_symbolic(st, idxs, dtype)
81
+
82
+ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
83
+ # TODO: symbolic shape
84
+ if not all_int(dims): return dims
85
+ while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
86
+ for i,m in enumerate(max_sizes):
87
+ if dims[i] * dims[i+1] <= m:
88
+ dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
89
+ break
90
+ else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
91
+ return dims
92
+
93
+ def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
94
+ if reverse: dims = dims[::-1]
95
+ limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
96
+ ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.pyint, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
97
+ if limited != dims:
98
+ ret = []
99
+ # cast for mypy, get_contraction won't be None
100
+ for idx, contraction in zip(raw_idxs, cast(List[List[int]], get_contraction(dims, limited))):
101
+ if len(contraction) == 1: ret.append(idx)
102
+ else:
103
+ for c in contraction:
104
+ ret.append(idx % dims[c])
105
+ idx //= dims[c]
106
+ return ret[::-1] if reverse else ret
107
+
108
+ class IndependentLowerer:
109
+ def lower(self, ast:LazyOp, opts:Renderer) -> UOp:
110
+ self.output_count = len(ast.src)
111
+
112
+ ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
113
+ # NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
114
+ full_shape = ast.full_shape
115
+ first_upcasted = len(full_shape)-ki.upcasted
116
+ # if there's no reduce, this is first_upcasted
117
+ first_reduce = [x!=y for x,y in zip(ast.src[0].arg.st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True)
118
+ local_loads = [x for x in ast.lazyops if x.op is BufferOps.LOAD and x.arg.idx == -1]
119
+ # NOTE: this is taking the first one...there may be subtleties here with multireduces
120
+ group_for_reduces = sum([x!=y for x,y in zip(
121
+ local_loads[0].arg.st.shape[first_reduce:first_upcasted], ast.src[0].arg.st.shape[first_reduce:first_upcasted])]) if local_loads else 0
122
+ global_dims = first_reduce-ki.local_dims
123
+
124
+ if opts.has_local:
125
+ if ki.dont_use_locals:
126
+ assert ki.local_dims == 0, "can't use locals if there's no local dims"
127
+ self.idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
128
+ else:
129
+ # define indexes for GPU-like execution
130
+ self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
131
+ get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
132
+ else:
133
+ # all loops are RANGES
134
+ self.idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
135
+ for i,g in enumerate(full_shape[:first_reduce])]
136
+
137
+ # reduce loops
138
+ self.idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
139
+ for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
140
+
141
+ # upcast loops
142
+ for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
143
+ assert isinstance(g, int), "needs to be int to upcast/unroll"
144
+ self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, tuple(UOp.const(dtypes.pyint, j) for j in range(0, g)), ((i,g),)))
145
+
146
+ # late indexes (group for reduce)
147
+ self.ridxs = self.idxs[:]
148
+ for a in range(first_reduce, first_reduce+group_for_reduces):
149
+ self.ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
150
+
151
+ self.uop_cache: Dict[LazyOp, UOp] = {}
152
+ return self.to_uop(ast)
153
+
154
+ def to_uop(self, x:LazyOp) -> UOp:
155
+ if uop:=self.uop_cache.get(x, None): return uop
156
+ ret = self._to_uop(x)
157
+ self.uop_cache[x] = ret
158
+ return ret
159
+
160
+ def _to_uop(self, x:LazyOp) -> UOp:
161
+ if x.op in BufferOps:
162
+ idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs,
163
+ x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.arg.dtype)
164
+ # TODO: check has_valid in UPat, not here
165
+ has_valid = valid.op is not UOps.CONST or valid.arg is not True
166
+ if x.op is BufferOps.CONST:
167
+ dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype
168
+ return valid.where(UOp.const(dtype, x.arg.val), UOp.const(dtype, 0))
169
+ if x.arg.idx < 0:
170
+ buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype),
171
+ arg=(f"temp{-x.arg.idx}", x.arg.st.real_size()))
172
+ else:
173
+ buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (), x.arg.idx)
174
+ if x.op is BufferOps.LOAD:
175
+ barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
176
+ load_dtype = x.arg.dtype.scalar()
177
+ if idx.dtype == dtypes.int.vec(3):
178
+ # this should all simplify if there's consts for id4. if not, w/e
179
+ idx, id4 = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx.src[0], idx.src[1])), idx.src[2]
180
+ vec_load = UOp(UOps.LOAD, load_dtype.vec(4), (buf, idx) + ((UOp.const(load_dtype.vec(4), 0), valid) if has_valid else ()) + barrier)
181
+ return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load_dtype, (vec_load,), i)),
182
+ range(4), UOp.const(load_dtype, float('nan')))
183
+ return UOp(UOps.LOAD, load_dtype, (buf, idx) + ((UOp.const(load_dtype, 0), valid) if has_valid else ()) + barrier)
184
+ # NOTE: only store the local reduceop in the first thread (this is wrong for non group for reduces!)
185
+ if x.arg.idx >= 0:
186
+ for oidx, ridx in zip(self.idxs, self.ridxs):
187
+ if oidx != ridx: valid = valid * oidx.eq(0)
188
+ has_valid = valid.op is not UOps.CONST or valid.arg is not True
189
+ return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))
190
+
191
+ in_uops = tuple(self.to_uop(y) for y in x.src)
192
+ if x.op is MetaOps.KERNEL: return UOp(UOps.SINK, src=in_uops)
193
+ if x.op is UnaryOps.CAST: return UOp(UOps.CAST, x.arg.scalar(), in_uops)
194
+ if x.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, x.arg.scalar(), in_uops)
195
+ if x.op in ReduceOps:
196
+ dtype = x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype
197
+ if x.op is ReduceOps.WMMA:
198
+ upcast_axes = x.arg[-2]
199
+ wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
200
+ ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=(
201
+ UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
202
+ UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
203
+ UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
204
+ return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=upcast_axes[2])
205
+ # NOTE: always using ridxs is fine here
206
+ reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg], lambda y: y.op is UOps.RANGE)
207
+ alu_op = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, x.op)]
208
+ ret = in_uops[0]
209
+ if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
210
+ ret = UOp(UOps.CONTRACT, dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
211
+ ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(cast(DType, ret.dtype).count)])
212
+ return UOp(UOps.REDUCE, dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
213
+ return in_uops[0].alu(x.op, *in_uops[1:])
214
+
215
+ def lazyop_to_uop(ast:LazyOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts)
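The indexing logic above (`_uop_view` / `st_to_uops_graph`) builds, per view, an index expression of the form offset + sum(idx*stride) and a validity predicate from the optional mask, all as UOp graph nodes. Below is a minimal plain-integer sketch of that arithmetic; the names are illustrative only and not part of the package:

```python
from typing import Optional, Sequence, Tuple

def flat_index(idxs: Sequence[int], shape: Sequence[int], strides: Sequence[int],
               offset: int = 0,
               mask: Optional[Sequence[Optional[Tuple[int, int]]]] = None) -> Tuple[int, bool]:
    # offset + sum(idx*stride) over non-broadcast dims; valid only while every masked idx is in [lo, hi)
    iexpr, valid = offset, True
    for idx, sh, st, m in zip(idxs, shape, strides, mask if mask is not None else [None] * len(shape)):
        if sh != 1 and st != 0: iexpr += idx * st
        if m is not None: valid = valid and (m[0] <= idx < m[1])
    return iexpr, valid

# a 4x5 row-major view: element (2, 3) lives at 2*5 + 3*1 = 13
assert flat_index((2, 3), (4, 5), (5, 1)) == (13, True)
# the same view masked so only columns 1..3 hold data: column 0 is padding, hence invalid
assert flat_index((2, 0), (4, 5), (5, 1), mask=(None, (1, 4))) == (10, False)
```

The shipped lowerer emits the same arithmetic as UOps and folds the remaining views together in `st_to_uops_graph`.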
tinygrad/codegen/transcendental.py
@@ -0,0 +1,310 @@
1
+ import math, functools
2
+ from typing import Tuple, List
3
+ from tinygrad.dtype import dtypes, DType
4
+ from tinygrad.codegen.uops import UOp
5
+
6
+ TRANSCENDENTAL_SUPPORTED_DTYPES = {dtypes.float16, dtypes.float32, dtypes.float64}
7
+
8
+ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
9
+ """replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio"""
10
+ return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
11
+ # *** helper functions for double/quad precision arithmetics ***
12
+ def dfadd2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx + yx, xy + yy
13
+ def dfmul2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx * yx, xx * yy + xy * yx
14
+ def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]:
15
+ t = dx.recip()
16
+ qx = nx * t
17
+ qy = (ny - qx * dy) * t
18
+ return qx, qy
19
+ # *** helper functions for bit manipulation ***
20
+ def significand_bits(d:DType) -> int: return {dtypes.float64: 52, dtypes.float32: 23, dtypes.float16: 10}[d]
21
+ def exponent_bias(d:DType) -> int: return {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 14}[d]
22
+ def exponent_mask(d:DType) -> int: return {dtypes.float64: 0x7FF, dtypes.float32: 0xFF, dtypes.float16: 0x1F}[d]
23
+
24
+ def float_to_bits(d:UOp) -> UOp:
25
+ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
26
+ cast_to = {dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[d.dtype]
27
+ return d.bitcast(cast_to)
28
+
29
+ def bits_to_float(d:UOp, float_dtype:DType) -> UOp:
30
+ assert d.dtype in [dtypes.uint64, dtypes.uint32, dtypes.uint16]
31
+ cast_to = {dtypes.uint64: dtypes.float64, dtypes.uint32: dtypes.float32, dtypes.uint16: float_dtype}[d.dtype]
32
+ return d.bitcast(cast_to)
33
+ # **** utils ****
34
+ def shr(x:UOp, y:int) -> UOp: return x // (2**y)
35
+ def shl(x:UOp, y:int) -> UOp: return x * (2**y)
36
+
37
+ def rintk(d:UOp) -> UOp:
38
+ """ceiling(d:float) -> int"""
39
+ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
40
+ return_t = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
41
+ return (d + d.lt(0.0).where(d.const(-0.5), d.const(0.5))).cast(return_t)
42
+
43
+ def pow2if(q:UOp, float_dtype:DType):
44
+ """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
45
+ assert q.dtype in (dtypes.int64, dtypes.int32, dtypes.int16, dtypes.uint32)
46
+ final_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype, dtypes.uint32: dtypes.float32}[q.dtype]
47
+ return shl((q + (exponent_bias(final_dtype)+1)), significand_bits(final_dtype)).bitcast(final_dtype)
48
+
49
+ def ilogb2k(d:UOp) -> UOp:
50
+ """calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
51
+ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
52
+ dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype])
53
+ # -1 <= ilog2bk(d) <= 128
54
+ # ((float_to_bits(d) >> significand_bits(dtype)) & exponent_mask(dtype)) - exponent_bias(dtype)
55
+ return (shr(dint, significand_bits(d.dtype)) & exponent_mask(d.dtype)) - (exponent_bias(d.dtype)+1)
56
+
57
+ def ldexp3k(d:UOp, e:UOp) -> UOp:
58
+ """d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
59
+ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
60
+ dtype = d.dtype
61
+ cast_map = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}
62
+ e = e.cast(cast_map[d.dtype])
63
+ m1 = d.bitcast(cast_map[d.dtype])
64
+ m2 = shl(e, significand_bits(d.dtype))
65
+ return (m1 + m2).bitcast(d.dtype).cast(dtype)
66
+
67
+ def ldexp2k(d:UOp, e:UOp) -> UOp:
68
+ """d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
69
+ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
70
+ return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
71
+
72
+ def frexp(v:UOp) -> Tuple[UOp, UOp]:
73
+ """frexp(v) -> (mantissa, exponent)"""
74
+ assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
75
+ # m1 = masks for mantissa, m2 = masks to normalize the mantissa.
76
+ m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
77
+ m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3C00}[v.dtype]
78
+ bias = {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 15}[v.dtype]
79
+ bits = float_to_bits(v)
80
+ exponent = shr(bits, significand_bits(v.dtype)) & exponent_mask(v.dtype)
81
+ exponent_zero = exponent.ne(0.0)
82
+ result_f = bits_to_float((bits & m1) | m2, v.dtype)
83
+ value = exponent_zero.where(result_f, v)
84
+ exp = exponent + (-bias)
85
+ exp = exponent_zero.where(exp, exp.const(0))
86
+ if v.dtype == dtypes.float16: exp = exp.bitcast(dtypes.int16)
87
+ return value, exp
88
+
89
+ def mla(x:UOp, y:UOp, z:UOp) -> UOp: return x * y + z
90
+
91
+ def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp: return functools.reduce(lambda u,c: mla(u, s, u.const(c)), coeffs, u)
92
+ # *** reduction algorithms for sine ***
93
+ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
94
+ """
95
+ Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
96
+ 39800.0 <= d <= +Inf
97
+ Returns a tuple of `(r, q)`:
98
+ - `r`[d.dtype] is the remainder value corresponding to `round_to_nearest(x % pi/2)`.
99
+ ensuring that `r` is in the range of [0, pi/2).
100
+ - `q`[int32] is an integer taking values 0,1,2 or 3, corresponding to the quadrant of the original angle `d`.
101
+ """
102
+ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
103
+ two_over_pi_f = [0x00000000,0x28be60db,0x9391054a,0x7f09d5f4,0x7d4d3770,0x36d8a566,0x4f10e410]
104
+
105
+ input_dtype: DType = d.dtype
106
+ dtype_via = dtypes.float32 if d.dtype == dtypes.float16 else d.dtype
107
+ acc_dtype = dtypes.uint64
108
+
109
+ f, e = frexp(d)
110
+ ia = (f.cast(dtype_via) * 4.294967296e9).cast(dtypes.uint64)
111
+ i = shr(e.cast(dtypes.uint64), 5)
112
+ e = (e.cast(dtypes.uint64) & 31).cast(dtypes.uint32)
113
+ offset = -e + 32
114
+
115
+ def _eq(arr:UOp, eq_to:int) -> UOp: return arr.ne(eq_to)
116
+ def _take(an:UOp, offset:int, count:int=0) -> UOp:
117
+ """an = two_over_pi_f[i+offset]"""
118
+ if count+offset <= len(two_over_pi_f[0:-2]):
119
+ an = _eq(i, count).where(_take(an, offset, count=count+1), an.const(two_over_pi_f[count+offset]))
120
+ return an
121
+ def _exact_pow2if(x): return pow2if(x, input_dtype).cast(acc_dtype)
122
+ def _shl_lazy(x, y): return (x.cast(acc_dtype) * _exact_pow2if(y)).cast(dtypes.uint32)
123
+ def _shr_lazy(x, y): return (x.cast(acc_dtype) // _exact_pow2if(y)).cast(dtypes.uint32)
124
+ # a_n = (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
125
+ a1 = _take(i.const(0).cast(dtypes.uint32), 0)
126
+ a2 = _take(i.const(0).cast(dtypes.uint32), 1)
127
+ a3 = _take(i.const(0).cast(dtypes.uint32), 2)
128
+ a4 = _take(i.const(0).cast(dtypes.uint32), 3)
129
+ # Note: e >= 1 for all numbers d >= 1.0. assume e != 0
130
+ hi = _shl_lazy(a1, e) | _shr_lazy(a2, offset)
131
+ mi = _shl_lazy(a2, e) | _shr_lazy(a3, offset)
132
+ lo = _shl_lazy(a3, e) | _shr_lazy(a4, offset)
133
+
134
+ def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
135
+ p = _hp_mul(ia, lo)
136
+ p = _hp_mul(ia, mi) + shr(p, 32)
137
+ p = shl(_hp_mul(ia, hi), 32) + p
138
+
139
+ q = shr(p, 62).cast(dtypes.int32)
140
+ p = p & 0x3fffffffffffffff
141
+ r = (p.cast(dtype_via) * (3.4061215800865545e-19)).cast(input_dtype)
142
+
143
+ # if fraction >= 0.5, r -= pi/2, q += 1
144
+ return f.lt(0.5).where(r, r + r.const(-math.pi / 2)), f.lt(0.5).where(q, q + 1)
145
+
146
+ def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
147
+ """
148
+ Performs Cody-Waite Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
149
+ 0 <= abs(d) <= 39800.0
150
+ Returns a tuple of `(r, q)`, where the output format is the same as that of `payne_hanek_reduction`.
151
+ """
152
+ m_1_pi = 0.318309886183790671537767526745028724
153
+ qdh = (d * (m_1_pi / 16777216)).cast(dtypes.int64).cast(d.dtype) * 16777216.0
154
+ def _quadrant(x:UOp) -> UOp:
155
+ if x.dtype == dtypes.float64: return rintk(mla(d, d.const(m_1_pi), -qdh)).cast(x.dtype)
156
+ return rintk(x * m_1_pi).cast(x.dtype)
157
+ def _reduce_d(x:UOp, q:UOp):
158
+ if x.dtype == dtypes.float64:
159
+ d = mla(qdh, x.const(-3.1415926218032836914), x)
160
+ d = mla(q, x.const(-3.1415926218032836914), d)
161
+ d = mla(qdh, x.const(-3.1786509424591713469e-08), d)
162
+ d = mla(q, x.const(-3.1786509424591713469e-08), d)
163
+ d = mla(qdh, x.const(-1.2246467864107188502e-16), d)
164
+ d = mla(q, x.const(-1.2246467864107188502e-16), d)
165
+ d = mla(qdh + q, x.const(-1.2736634327021899816e-24), d)
166
+ elif x.dtype == dtypes.float16:
167
+ # [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
168
+ d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
169
+ else:
170
+ d = mla(q, x.const(-3.1414794921875), x)
171
+ d = mla(q, x.const(-0.00011315941810607910156), d)
172
+ d = mla(q, x.const(-1.9841872589410058936e-09), d)
173
+ d = mla(q, x.const(-1.2154201256553420762e-10), d)
174
+ return d
175
+ return _reduce_d(d, (q := _quadrant(d))), q.cast(dtypes.int32)
176
+ # *** approximate sine on small angle. ***
177
+ def trig_poly(d:UOp, coeff32, coeff64):
178
+ u = None
179
+ s = d * d
180
+ if d.dtype == dtypes.float64:
181
+ s2 = s * s
182
+ s4 = s2 * s2
183
+ def __poly4(x:UOp, x2:UOp, c3, c2, c1, c0) -> UOp: return mla(x2, mla(x, x.const(c3), x.const(c2)), mla(x, x.const(c1), x.const(c0)))
184
+ def __poly8(x, x2, x4, c7, c6, c5, c4, c3, c2, c1, c0) -> UOp: return mla(x4, __poly4(x, x2, c7, c6, c5, c4), __poly4(x, x2, c3, c2, c1, c0))
185
+ u = __poly8(s, s2, s4, *coeff64[:-1])
186
+ u = mla(u, s, d.const(coeff64[-1]))
187
+ else:
188
+ u = polyN(s.const(coeff32[0]), s, coeff32[1:])
189
+ return mla(s, u * d, d)
190
+ # approximate sine on [-pi/2, pi/2]
191
+ def sin_poly(d:UOp) -> UOp:
192
+ return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938],
193
+ [-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10,
194
+ -2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815,
195
+ -0.166666666666666657414808])
196
+
197
+ def sin_poly_small(d:UOp, q:UOp) -> UOp:
198
+ def _ifand(n:int): return (q & n).ne(0)
199
+ r = sin_poly(d)
200
+ return r * _ifand(1).where(r.const(-1), r.const(1))
201
+
202
+ def sin_poly_large(d:UOp, q:UOp) -> UOp:
203
+ def _ifand(n:int): return (q & n).ne(0)
204
+ d = d + _ifand(1).where(d.const(math.pi / 2), d.const(0))
205
+ r = sin_poly(d)
206
+ return r * _ifand(2).where(r.const(-1), r.const(1))
207
+
208
+ # *** toplevel functions for xsin/xlog2/xexp2 ***
209
+
210
+ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
211
+ """
212
+ Implements a 1.0 ULP approximation for UnaryOps.SIN.
213
+ - fast=True assumes x <= switch_over.
214
+ - switch_over is the threshold for switching to payne_hanek_reduction.
215
+ """
216
+ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
217
+ reduction_algo = cody_waite_reduction if fast else payne_hanek_reduction
218
+ # mask +-inf/nan as zero
219
+ x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d)
220
+ # x_sign = sign(x)
221
+ x_sign = x.ne(0).where(x.lt(0).where(x.const(-1), x.const(1)), x.const(0))
222
+ x_abs = x * x_sign
223
+ r, q = reduction_algo(x_abs)
224
+ if fast: result = sin_poly_small(r, q)
225
+ else:
226
+ # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
227
+ switch_over_map = x_abs.lt(switch_over)
228
+ r_fast, q_fast = cody_waite_reduction(x_abs)
229
+ r = switch_over_map.where(r_fast, r)
230
+ q = switch_over_map.where(q_fast, q)
231
+ result = switch_over_map.where(sin_poly_small(r, q), sin_poly_large(r, q))
232
+ result = result * x_sign # adjusts the sign for abs(x).
233
+ # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
234
+ return _lazy_map_numbers(d, d.const(math.nan), d.const(math.nan), d.const(math.nan), result)
235
+
236
+ def xexp2(d:UOp) -> UOp:
237
+ """
238
+ Implements a 1.0 ULP approximation for UnaryOps.EXP2
239
+ - Paper: https://arxiv.org/pdf/2001.09258
240
+ """
241
+ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
242
+ fp64_p = d.dtype == dtypes.float64
243
+ # mask +-inf/nan as zero.
244
+ x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d)
245
+ q = rintk(x)
246
+ # s = d - round(d)
247
+ s = x - q.cast(x.dtype)
248
+ # a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2].
249
+ if fp64_p:
250
+ u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5,
251
+ 0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2,
252
+ 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
253
+ 0.6931471805599452862e+0, 0.1000000000000000000e+1])
254
+ else:
255
+ u = polyN(s.const(0.1535920892e-3), s, [0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1])
256
+ u = ldexp2k(u, q) # u*2^q
257
+ upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[x.dtype]
258
+ lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[x.dtype]
259
+ # Replace x >= upper with +inf
260
+ u = x.ne(upper).where(u, x.const(math.inf))
261
+ u = x.lt(upper).where(u, x.const(math.inf))
262
+ # Replace x <= lower with zero.
263
+ u = x.lt(lower).where(x.const(0.0), u)
264
+ # x=NaN never satisfies x < Inf. (for fastmode)
265
+ u = x.lt(math.inf).where(u, u.const(math.nan))
266
+ # exp2(Inf) = Inf, exp2(-Inf) = 0, exp2(NaN) = NaN
267
+ return _lazy_map_numbers(d, d.const(math.inf), d.const(0.0), d.const(math.nan), u)
268
+
269
+ def xlog2(d:UOp) -> UOp:
270
+ """
271
+ Implements a 1.0 ULP approximation for UnaryOps.LOG2
272
+ Paper: https://arxiv.org/pdf/2001.09258
273
+ """
274
+ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
275
+ fp64_p = d.dtype == dtypes.float64
276
+ FLT_MIN = d.const(1e-6 if d.dtype == dtypes.float16 else 1e-4)
277
+ d_orig = d
278
+ denormal_map = d.lt(FLT_MIN)
279
+ for _ in range(8): d = denormal_map.where(d * (2 ** 8), d)
280
+
281
+ e = ilogb2k(d * (1.0 / 0.75)).cast(d.dtype)
282
+ m = ldexp3k(d, -e)
283
+ e = denormal_map.where(e + (-64), e)
284
+
285
+ if fp64_p:
286
+ x = (m - 1.0) * (m + 1.0).recip()
287
+ x2 = x * x
288
+ t = polyN(x.const(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
289
+ 0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
290
+ s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(t.const(2.885390081777926774), t.const(0), x, x.const(0)))
291
+ r = mla(t, x * x2, s_hi + s_lo)
292
+ else:
293
+ xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const(-1), m.const(0), m, m.const(0)), *dfadd2_f2_f2_f2(m.const(1), m.const(0), m, m.const(0)))
294
+ x2 = xx * xx
295
+ t = polyN(d.const(0.4374550283e+0), x2, [0.5764790177e+0, 0.9618012905120])
296
+ sx, sy = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(xx, xy, xx.const(2.8853900432586669922), xy.const(3.2734474483568488616e-08)))
297
+ sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const(0), (x2 * xx) * t)
298
+ r = sx + sy
299
+ # log2(Inf) = Inf
300
+ r = d_orig.ne(math.inf).where(r, r.const(math.inf))
301
+ # log2(x=-0.01) = NaN. where x < 0
302
+ r = d_orig.lt(-0.0).where(r.const(math.nan), r)
303
+ # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
304
+ # log2_zero = the value of unmasked xlog2(0.0).
305
+ log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79, None: -math.inf}[d.dtype]
306
+ r = r.ne(log2_zero).where(r, r.const(-math.inf))
307
+ # log(NaN) = NaN, using for all real number x, either of x < Inf, x == Inf becomes True.
308
+ r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const(math.nan), d))
309
+ # log(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
310
+ return d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))
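The transcendental helpers above do all of their bit manipulation on UOps so the generated kernels stay backend-agnostic. As a rough plain-Python illustration of the float32 bit layout that `float_to_bits`, `pow2if` and `ilogb2k` rely on (a sketch only; these helper names are not part of the package):

```python
import struct

def f32_bits(x: float) -> int:
    # float_to_bits analogue: reinterpret a float32 as its uint32 bit pattern
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_f32(b: int) -> float:
    # bits_to_float analogue: reinterpret a uint32 bit pattern as float32
    return struct.unpack("<f", struct.pack("<I", b))[0]

def pow2if_f32(q: int) -> float:
    # pow2if analogue: 2**q built by writing q + 127 into the float32 exponent field
    assert -126 <= q <= 127
    return bits_f32((q + 127) << 23)

def ilogb2k_f32(x: float) -> int:
    # ilogb2k analogue: integer part of log2(x) for normalized x, read back from the exponent field
    return ((f32_bits(x) >> 23) & 0xFF) - 127

assert pow2if_f32(10) == 1024.0 and pow2if_f32(-3) == 0.125
assert ilogb2k_f32(10.0) == 3 and ilogb2k_f32(0.75) == -1 and ilogb2k_f32(1.0) == 0
```

pow2if places q + 127 directly into the float32 exponent field (the table in the diff stores the bias as 126 and adds 1 where needed), and ilogb2k reads that field back out; the shipped versions express the same shifts, masks and bitcasts over UOps.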