numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,467 @@
1
+ /**
2
+ * @brief SIMD-accelerated Trigonometric Functions.
3
+ * @file include/numkong/trigonometry.h
4
+ * @author Ash Vardanian
5
+ * @date July 1, 2023
6
+ * @see SLEEF: https://sleef.org/
7
+ *
8
+ * Contains:
9
+ *
10
+ * - Sine and Cosine approximations: fast for `f32` vs accurate for `f64`
11
+ * - Tangent and the 2-argument arctangent: fast for `f32` vs accurate for `f64`
12
+ *
13
+ * For dtypes:
14
+ *
15
+ * - 64-bit IEEE-754 floating point
16
+ * - 32-bit IEEE-754 floating point
17
+ * - 16-bit IEEE-754 floating point
18
+ *
19
+ * For hardware architectures (plus the `v128relaxed` and `rvv` backends declared below):
20
+ *
21
+ * - Arm: NEON
22
+ * - x86: Haswell, Skylake, Sapphire Rapids
23
+ *
24
+ * Those functions partially complement the `each.h` module, and are necessary for
25
+ * the `geospatial.h` module, among others. Both Haversine and Vincenty's formulas require
26
+ * trigonometric functions, and those are the most expensive part of the computation.
27
+ *
28
+ * @section glibc_math GLibC IEEE-754-compliant Math Functions
29
+ *
30
+ * The GNU C Library (GLibC) provides a set of IEEE-754-compliant math functions, like `sinf`, `cosf`,
31
+ * and double-precision variants `sin`, `cos`. Those functions are accurate to ~0.55 ULP (units in the
32
+ * last place), but can be slow to evaluate. They use a combination of techniques, like:
33
+ *
34
+ * - Taylor series expansions for small values.
35
+ * - Table lookups combined with corrections for moderate values.
36
+ * - Accurate modulo reduction for large values.
37
+ *
38
+ * The precomputed tables may be the hardest part to accelerate with SIMD, as they contain 440 values,
39
+ * each 64-bit wide.
40
+ *
41
+ * https://github.com/lattera/glibc/blob/895ef79e04a953cac1493863bcae29ad85657ee1/sysdeps/ieee754/dbl-64/branred.c#L54
42
+ * https://github.com/lattera/glibc/blob/895ef79e04a953cac1493863bcae29ad85657ee1/sysdeps/ieee754/dbl-64/s_sin.c#L84
43
+ *
44
+ * @section approximation_algorithms Approximation Algorithms
45
+ *
46
+ * There are several ways to approximate trigonometric functions, and the choice depends on the
47
+ * target hardware and the desired precision. Notably:
48
+ *
49
+ * - Taylor series approximation expands a function into a sum of terms built from its derivatives at a target point.
50
+ * It's easy to derive for differentiable functions, works well for functions smooth around the
51
+ * expansion point, but can perform poorly for functions with singularities or high-frequency
52
+ * oscillations.
53
+ *
54
+ * - Padé approximants are rational functions that approximate a function by a ratio of polynomials.
55
+ * They often converge faster than Taylor series for functions with singularities or steep changes, provide
56
+ * good approximations for both smooth and rational functions, but can be more computationally
57
+ * intensive to evaluate, and can have poles (points where the denominator vanishes).
58
+ *
59
+ * Moreover, most approximations can be combined with Horner's method of evaluating polynomials
60
+ * to reduce the number of multiplications and additions, and to improve the numerical stability.
61
+ * In trigonometry, the Payne-Hanek Range Reduction is another technique used to reduce the argument
62
+ * to a smaller range, where the approximation is more accurate.
63
+ *
64
+ * @section optimization_notes Optimization Notes
65
+ *
66
+ * The following optimizations were evaluated but did not yield performance improvements:
67
+ *
68
+ * - Estrin's scheme for polynomial evaluation: This tree-based approach reduces the dependency depth
69
+ * from N sequential FMAs to log2(N) by computing powers of x in parallel with partial sums.
70
+ * For an 8-term polynomial, Estrin reduces depth from 7 to 3. However, benchmarks showed ~20%
71
+ * regression because the extra MUL operations for computing x², x⁴, x⁸ hurt throughput more
72
+ * than the reduced dependency depth helps latency. For large arrays, out-of-order execution
73
+ * across loop iterations already hides FMA latency, making throughput the bottleneck.
74
+ *
75
+ * - RCPPS with Newton-Raphson refinement: Fast reciprocal approximation (~4 cycles) with one
76
+ * refinement iteration for ~22-bit precision, tested as an alternative to VDIVPS (~11 cycles).
77
+ * Did not improve performance when combined with Estrin's scheme, likely because the division
78
+ * is not on the critical path when processing large arrays.
79
+ *
80
+ * @section x86_instructions Relevant x86 Instructions
81
+ *
82
+ * Polynomial evaluation (Horner's method) for sin/cos/tan uses chained FMAs - the 4-cycle latency
83
+ * is hidden by out-of-order execution across iterations. Range reduction uses VRNDSCALE for fast
84
+ * rounding (notably 3x faster on Genoa than Ice Lake). VFPCLASS detects NaN/Inf inputs for special
85
+ * case handling. Division appears in tangent's final step but isn't on the critical path.
86
+ *
87
+ * Intrinsic Instruction Ice Genoa
88
+ * _mm512_roundscale_ps VRNDSCALEPS (ZMM, ZMM, I8) 8c @ p0 3c @ p23
89
+ * _mm512_roundscale_pd VRNDSCALEPD (ZMM, ZMM, I8) 8c @ p0 3c @ p23
90
+ * _mm512_fpclass_ps_mask VFPCLASSPS (K, ZMM, I8) 3c @ p5 5c @ p01
91
+ * _mm512_fmadd_ps VFMADD231PS (ZMM, ZMM, ZMM) 4c @ p0 4c @ p01
92
+ * _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 4c @ p01 4c @ p01
93
+ * _mm256_div_ps VDIVPS (YMM, YMM, YMM) ~14c @ p0 ~11c @ p01
94
+ * _mm256_div_pd VDIVPD (YMM, YMM, YMM) ~23c @ p0 ~13c @ p01
95
+ *
96
+ * @section arm_instructions Relevant ARM NEON/SVE Instructions
97
+ *
98
+ * ARM implementations use the same Horner polynomial approach with FMLA chains. FRINTA provides
99
+ * fast rounding for range reduction. The 4-cycle FMA latency with 4 inst/cycle throughput allows
100
+ * excellent pipelining when processing multiple elements.
101
+ *
102
+ * Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
103
+ * vfmaq_f32 FMLA.S (vec) 4c @ V0123 4c @ V0123 4c @ V0123
104
+ * vfmaq_f64 FMLA.D (vec) 4c @ V0123 4c @ V0123 4c @ V0123
105
+ * vrndaq_f32 FRINTA.S 2c @ V0123 2c @ V01 2c @ V01
106
+ *
107
+ * @section references References
108
+ *
109
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
110
+ * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
111
+ *
112
+ */
113
+ #ifndef NK_TRIGONOMETRY_H
114
+ #define NK_TRIGONOMETRY_H
115
+
116
+ #include "numkong/types.h"
117
+
118
+ #if defined(__cplusplus)
119
+ extern "C" {
120
+ #endif
121
+
122
+ /**
123
+ * @brief Element-wise sine over f64 inputs in radians.
124
+ *
125
+ * @param[in] ins Input array of angles in radians.
126
+ * @param[in] n Number of elements in the input/output arrays.
127
+ * @param[out] outs Output array of sine values.
128
+ */
129
+ NK_DYNAMIC void nk_each_sin_f64(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
130
+
131
+ /**
132
+ * @brief Element-wise cosine over f64 inputs in radians.
133
+ *
134
+ * @param[in] ins Input array of angles in radians.
135
+ * @param[in] n Number of elements in the input/output arrays.
136
+ * @param[out] outs Output array of cosine values.
137
+ */
138
+ NK_DYNAMIC void nk_each_cos_f64(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
139
+
140
+ /**
141
+ * @brief Element-wise arc-tangent over f64 inputs.
142
+ *
143
+ * @param[in] ins Input array of values to take the arc-tangent of.
144
+ * @param[in] n Number of elements in the input/output arrays.
145
+ * @param[out] outs Output array of arc-tangent values.
146
+ */
147
+ NK_DYNAMIC void nk_each_atan_f64(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
148
+
149
+ /**
150
+ * @brief Element-wise sine over f32 inputs in radians.
151
+ *
152
+ * @param[in] ins Input array of angles in radians.
153
+ * @param[in] n Number of elements in the input/output arrays.
154
+ * @param[out] outs Output array of sine values.
155
+ */
156
+ NK_DYNAMIC void nk_each_sin_f32(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
157
+
158
+ /**
159
+ * @brief Element-wise cosine over f32 inputs in radians.
160
+ *
161
+ * @param[in] ins Input array of angles in radians.
162
+ * @param[in] n Number of elements in the input/output arrays.
163
+ * @param[out] outs Output array of cosine values.
164
+ */
165
+ NK_DYNAMIC void nk_each_cos_f32(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
166
+
167
+ /**
168
+ * @brief Element-wise arc-tangent over f32 inputs.
169
+ *
170
+ * @param[in] ins Input array of input values.
171
+ * @param[in] n Number of elements in the input/output arrays.
172
+ * @param[out] outs Output array of arc-tangent values.
173
+ */
174
+ NK_DYNAMIC void nk_each_atan_f32(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
175
+
176
+ /**
177
+ * @brief Element-wise sine over f16 inputs in radians.
178
+ *
179
+ * @param[in] ins Input array of angles in radians.
180
+ * @param[in] n Number of elements in the input/output arrays.
181
+ * @param[out] outs Output array of sine values.
182
+ */
183
+ NK_DYNAMIC void nk_each_sin_f16(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
184
+
185
+ /**
186
+ * @brief Element-wise cosine over f16 inputs in radians.
187
+ *
188
+ * @param[in] ins Input array of angles in radians.
189
+ * @param[in] n Number of elements in the input/output arrays.
190
+ * @param[out] outs Output array of cosine values.
191
+ */
192
+ NK_DYNAMIC void nk_each_cos_f16(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
193
+
194
+ /**
195
+ * @brief Element-wise arc-tangent over f16 inputs.
196
+ *
197
+ * @param[in] ins Input array of input values.
198
+ * @param[in] n Number of elements in the input/output arrays.
199
+ * @param[out] outs Output array of arc-tangent values.
200
+ */
201
+ NK_DYNAMIC void nk_each_atan_f16(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
202
+
203
+ /** @copydoc nk_each_sin_f64 */
204
+ NK_PUBLIC void nk_each_sin_f64_serial(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
205
+ /** @copydoc nk_each_cos_f64 */
206
+ NK_PUBLIC void nk_each_cos_f64_serial(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
207
+ /** @copydoc nk_each_atan_f64 */
208
+ NK_PUBLIC void nk_each_atan_f64_serial(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
209
+ /** @copydoc nk_each_sin_f32 */
210
+ NK_PUBLIC void nk_each_sin_f32_serial(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
211
+ /** @copydoc nk_each_cos_f32 */
212
+ NK_PUBLIC void nk_each_cos_f32_serial(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
213
+ /** @copydoc nk_each_atan_f32 */
214
+ NK_PUBLIC void nk_each_atan_f32_serial(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
215
+ /** @copydoc nk_each_sin_f16 */
216
+ NK_PUBLIC void nk_each_sin_f16_serial(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
217
+ /** @copydoc nk_each_cos_f16 */
218
+ NK_PUBLIC void nk_each_cos_f16_serial(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
219
+ /** @copydoc nk_each_atan_f16 */
220
+ NK_PUBLIC void nk_each_atan_f16_serial(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
221
+
222
+ #if NK_TARGET_NEON
223
+ /** @copydoc nk_each_sin_f64 */
224
+ NK_PUBLIC void nk_each_sin_f64_neon(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
225
+ /** @copydoc nk_each_cos_f64 */
226
+ NK_PUBLIC void nk_each_cos_f64_neon(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
227
+ /** @copydoc nk_each_atan_f64 */
228
+ NK_PUBLIC void nk_each_atan_f64_neon(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
229
+ /** @copydoc nk_each_sin_f32 */
230
+ NK_PUBLIC void nk_each_sin_f32_neon(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
231
+ /** @copydoc nk_each_cos_f32 */
232
+ NK_PUBLIC void nk_each_cos_f32_neon(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
233
+ /** @copydoc nk_each_atan_f32 */
234
+ NK_PUBLIC void nk_each_atan_f32_neon(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
235
+ #endif // NK_TARGET_NEON
236
+
237
+ /* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer, using 32-bit arithmetic over 256-bit words.
238
+ * First demonstrated in 2011, at least one Haswell-based processor was still being sold in 2022 — the Pentium G3420.
239
+ * Practically all modern x86 CPUs support AVX2, FMA, and F16C, making it a perfect baseline for SIMD algorithms.
240
+ * On the other hand, there is no need to implement AVX2 versions of `f32` and `f64` functions, as those are
241
+ * properly vectorized by recent compilers.
242
+ */
243
+ #if NK_TARGET_HASWELL
244
+ /** @copydoc nk_each_sin_f64 */
245
+ NK_PUBLIC void nk_each_sin_f64_haswell(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
246
+ /** @copydoc nk_each_cos_f64 */
247
+ NK_PUBLIC void nk_each_cos_f64_haswell(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
248
+ /** @copydoc nk_each_atan_f64 */
249
+ NK_PUBLIC void nk_each_atan_f64_haswell(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
250
+ /** @copydoc nk_each_sin_f32 */
251
+ NK_PUBLIC void nk_each_sin_f32_haswell(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
252
+ /** @copydoc nk_each_cos_f32 */
253
+ NK_PUBLIC void nk_each_cos_f32_haswell(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
254
+ /** @copydoc nk_each_atan_f32 */
255
+ NK_PUBLIC void nk_each_atan_f32_haswell(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
256
+ #endif // NK_TARGET_HASWELL
257
+
258
+ /* SIMD-powered backends for various generations of AVX512 CPUs.
259
+ * Skylake is handy, as it supports masked loads and other operations, avoiding the need for the tail loop.
260
+ */
261
+ #if NK_TARGET_SKYLAKE
262
+ /** @copydoc nk_each_sin_f64 */
263
+ NK_PUBLIC void nk_each_sin_f64_skylake(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
264
+ /** @copydoc nk_each_cos_f64 */
265
+ NK_PUBLIC void nk_each_cos_f64_skylake(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
266
+ /** @copydoc nk_each_atan_f64 */
267
+ NK_PUBLIC void nk_each_atan_f64_skylake(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
268
+ /** @copydoc nk_each_sin_f32 */
269
+ NK_PUBLIC void nk_each_sin_f32_skylake(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
270
+ /** @copydoc nk_each_cos_f32 */
271
+ NK_PUBLIC void nk_each_cos_f32_skylake(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
272
+ /** @copydoc nk_each_atan_f32 */
273
+ NK_PUBLIC void nk_each_atan_f32_skylake(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
274
+ /** @copydoc nk_each_sin_f16 */
275
+ NK_PUBLIC void nk_each_sin_f16_skylake(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
276
+ /** @copydoc nk_each_cos_f16 */
277
+ NK_PUBLIC void nk_each_cos_f16_skylake(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
278
+ /** @copydoc nk_each_atan_f16 */
279
+ NK_PUBLIC void nk_each_atan_f16_skylake(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
280
+ #endif // NK_TARGET_SKYLAKE
281
+
282
+ #if NK_TARGET_V128RELAXED
283
+ /** @copydoc nk_each_sin_f64 */
284
+ NK_PUBLIC void nk_each_sin_f64_v128relaxed(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
285
+ /** @copydoc nk_each_cos_f64 */
286
+ NK_PUBLIC void nk_each_cos_f64_v128relaxed(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
287
+ /** @copydoc nk_each_atan_f64 */
288
+ NK_PUBLIC void nk_each_atan_f64_v128relaxed(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
289
+ /** @copydoc nk_each_sin_f32 */
290
+ NK_PUBLIC void nk_each_sin_f32_v128relaxed(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
291
+ /** @copydoc nk_each_cos_f32 */
292
+ NK_PUBLIC void nk_each_cos_f32_v128relaxed(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
293
+ /** @copydoc nk_each_atan_f32 */
294
+ NK_PUBLIC void nk_each_atan_f32_v128relaxed(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
295
+ #endif // NK_TARGET_V128RELAXED
296
+
297
+ #if NK_TARGET_RVV
298
+ /** @copydoc nk_each_sin_f64 */
299
+ NK_PUBLIC void nk_each_sin_f64_rvv(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
300
+ /** @copydoc nk_each_cos_f64 */
301
+ NK_PUBLIC void nk_each_cos_f64_rvv(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
302
+ /** @copydoc nk_each_atan_f64 */
303
+ NK_PUBLIC void nk_each_atan_f64_rvv(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs);
304
+ /** @copydoc nk_each_sin_f32 */
305
+ NK_PUBLIC void nk_each_sin_f32_rvv(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
306
+ /** @copydoc nk_each_cos_f32 */
307
+ NK_PUBLIC void nk_each_cos_f32_rvv(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
308
+ /** @copydoc nk_each_atan_f32 */
309
+ NK_PUBLIC void nk_each_atan_f32_rvv(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs);
310
+ /** @copydoc nk_each_sin_f16 */
311
+ NK_PUBLIC void nk_each_sin_f16_rvv(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
312
+ /** @copydoc nk_each_cos_f16 */
313
+ NK_PUBLIC void nk_each_cos_f16_rvv(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
314
+ /** @copydoc nk_each_atan_f16 */
315
+ NK_PUBLIC void nk_each_atan_f16_rvv(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs);
316
+ #endif // NK_TARGET_RVV
317
+
318
+ #if defined(__cplusplus)
319
+ } // extern "C"
320
+ #endif
321
+
322
+ #include "numkong/trigonometry/serial.h"
323
+ #include "numkong/trigonometry/neon.h"
324
+ #include "numkong/trigonometry/haswell.h"
325
+ #include "numkong/trigonometry/skylake.h"
326
+ #include "numkong/trigonometry/v128relaxed.h"
327
+ #include "numkong/trigonometry/rvv.h"
328
+
329
+ #if defined(__cplusplus)
330
+ extern "C" {
331
+ #endif
332
+
333
+ #if !NK_DYNAMIC_DISPATCH
334
+
335
+ NK_PUBLIC void nk_each_sin_f64(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs) {
336
+ #if NK_TARGET_NEON
337
+ nk_each_sin_f64_neon(ins, n, outs);
338
+ #elif NK_TARGET_SKYLAKE
339
+ nk_each_sin_f64_skylake(ins, n, outs);
340
+ #elif NK_TARGET_HASWELL
341
+ nk_each_sin_f64_haswell(ins, n, outs);
342
+ #elif NK_TARGET_V128RELAXED
343
+ nk_each_sin_f64_v128relaxed(ins, n, outs);
344
+ #elif NK_TARGET_RVV
345
+ nk_each_sin_f64_rvv(ins, n, outs);
346
+ #else
347
+ nk_each_sin_f64_serial(ins, n, outs);
348
+ #endif
349
+ }
350
+
351
+ NK_PUBLIC void nk_each_cos_f64(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs) {
352
+ #if NK_TARGET_NEON
353
+ nk_each_cos_f64_neon(ins, n, outs);
354
+ #elif NK_TARGET_SKYLAKE
355
+ nk_each_cos_f64_skylake(ins, n, outs);
356
+ #elif NK_TARGET_HASWELL
357
+ nk_each_cos_f64_haswell(ins, n, outs);
358
+ #elif NK_TARGET_V128RELAXED
359
+ nk_each_cos_f64_v128relaxed(ins, n, outs);
360
+ #elif NK_TARGET_RVV
361
+ nk_each_cos_f64_rvv(ins, n, outs);
362
+ #else
363
+ nk_each_cos_f64_serial(ins, n, outs);
364
+ #endif
365
+ }
366
+
367
+ NK_PUBLIC void nk_each_atan_f64(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs) {
368
+ #if NK_TARGET_NEON
369
+ nk_each_atan_f64_neon(ins, n, outs);
370
+ #elif NK_TARGET_SKYLAKE
371
+ nk_each_atan_f64_skylake(ins, n, outs);
372
+ #elif NK_TARGET_HASWELL
373
+ nk_each_atan_f64_haswell(ins, n, outs);
374
+ #elif NK_TARGET_V128RELAXED
375
+ nk_each_atan_f64_v128relaxed(ins, n, outs);
376
+ #elif NK_TARGET_RVV
377
+ nk_each_atan_f64_rvv(ins, n, outs);
378
+ #else
379
+ nk_each_atan_f64_serial(ins, n, outs);
380
+ #endif
381
+ }
382
+
383
+ NK_PUBLIC void nk_each_sin_f32(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs) {
384
+ #if NK_TARGET_NEON
385
+ nk_each_sin_f32_neon(ins, n, outs);
386
+ #elif NK_TARGET_SKYLAKE
387
+ nk_each_sin_f32_skylake(ins, n, outs);
388
+ #elif NK_TARGET_HASWELL
389
+ nk_each_sin_f32_haswell(ins, n, outs);
390
+ #elif NK_TARGET_V128RELAXED
391
+ nk_each_sin_f32_v128relaxed(ins, n, outs);
392
+ #elif NK_TARGET_RVV
393
+ nk_each_sin_f32_rvv(ins, n, outs);
394
+ #else
395
+ nk_each_sin_f32_serial(ins, n, outs);
396
+ #endif
397
+ }
398
+
399
+ NK_PUBLIC void nk_each_cos_f32(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs) {
400
+ #if NK_TARGET_NEON
401
+ nk_each_cos_f32_neon(ins, n, outs);
402
+ #elif NK_TARGET_SKYLAKE
403
+ nk_each_cos_f32_skylake(ins, n, outs);
404
+ #elif NK_TARGET_HASWELL
405
+ nk_each_cos_f32_haswell(ins, n, outs);
406
+ #elif NK_TARGET_V128RELAXED
407
+ nk_each_cos_f32_v128relaxed(ins, n, outs);
408
+ #elif NK_TARGET_RVV
409
+ nk_each_cos_f32_rvv(ins, n, outs);
410
+ #else
411
+ nk_each_cos_f32_serial(ins, n, outs);
412
+ #endif
413
+ }
414
+
415
+ NK_PUBLIC void nk_each_atan_f32(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs) {
416
+ #if NK_TARGET_NEON
417
+ nk_each_atan_f32_neon(ins, n, outs);
418
+ #elif NK_TARGET_SKYLAKE
419
+ nk_each_atan_f32_skylake(ins, n, outs);
420
+ #elif NK_TARGET_HASWELL
421
+ nk_each_atan_f32_haswell(ins, n, outs);
422
+ #elif NK_TARGET_V128RELAXED
423
+ nk_each_atan_f32_v128relaxed(ins, n, outs);
424
+ #elif NK_TARGET_RVV
425
+ nk_each_atan_f32_rvv(ins, n, outs);
426
+ #else
427
+ nk_each_atan_f32_serial(ins, n, outs);
428
+ #endif
429
+ }
430
+
431
+ NK_PUBLIC void nk_each_sin_f16(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs) {
432
+ #if NK_TARGET_SKYLAKE
433
+ nk_each_sin_f16_skylake(ins, n, outs);
434
+ #elif NK_TARGET_RVV
435
+ nk_each_sin_f16_rvv(ins, n, outs);
436
+ #else
437
+ nk_each_sin_f16_serial(ins, n, outs);
438
+ #endif
439
+ }
440
+
441
+ NK_PUBLIC void nk_each_cos_f16(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs) {
442
+ #if NK_TARGET_SKYLAKE
443
+ nk_each_cos_f16_skylake(ins, n, outs);
444
+ #elif NK_TARGET_RVV
445
+ nk_each_cos_f16_rvv(ins, n, outs);
446
+ #else
447
+ nk_each_cos_f16_serial(ins, n, outs);
448
+ #endif
449
+ }
450
+
451
+ NK_PUBLIC void nk_each_atan_f16(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs) {
452
+ #if NK_TARGET_SKYLAKE
453
+ nk_each_atan_f16_skylake(ins, n, outs);
454
+ #elif NK_TARGET_RVV
455
+ nk_each_atan_f16_rvv(ins, n, outs);
456
+ #else
457
+ nk_each_atan_f16_serial(ins, n, outs);
458
+ #endif
459
+ }
460
+
461
+ #endif // !NK_DYNAMIC_DISPATCH
462
+
463
+ #if defined(__cplusplus)
464
+ } // extern "C"
465
+ #endif
466
+
467
+ #endif // NK_TRIGONOMETRY_H
@@ -0,0 +1,166 @@
1
+ /**
2
+ * @brief C++ bindings for trigonometric kernels.
3
+ * @file include/numkong/trigonometry.hpp
4
+ * @author Ash Vardanian
5
+ * @date February 5, 2026
6
+ */
7
+ #ifndef NK_TRIGONOMETRY_HPP
8
+ #define NK_TRIGONOMETRY_HPP
9
+
10
+ #include <cstdint>
11
+ #include <type_traits>
12
+
13
+ #include "numkong/trigonometry.h"
14
+
15
+ #include "numkong/types.hpp"
16
+
17
+ namespace ashvardanian::numkong {
18
+
19
+ /**
20
+ * @brief Array sine: outᵢ = sin(inᵢ)
21
+ * @param[in] in Input array
22
+ * @param[in] n Number of elements
23
+ * @param[out] out Output array
24
+ *
25
+ * @tparam in_type_ Element type (f32_t, f64_t, f16_t)
26
+ * @tparam precision_type_ Precision type for scalar fallback, defaults to `in_type_`
27
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
28
+ */
29
+ template <numeric_dtype in_type_, numeric_dtype precision_type_ = in_type_, allow_simd_t allow_simd_ = prefer_simd_k>
30
+ void sin(in_type_ const *in, std::size_t n, in_type_ *out) noexcept {
31
+ constexpr bool simd = allow_simd_ == prefer_simd_k && std::is_same_v<in_type_, precision_type_>;
32
+
33
+ if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_each_sin_f64(&in->raw_, n, &out->raw_);
34
+ else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_each_sin_f32(&in->raw_, n, &out->raw_);
35
+ else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_each_sin_f16(&in->raw_, n, &out->raw_);
36
+ // Scalar fallback
37
+ else {
38
+ for (std::size_t i = 0; i < n; i++) out[i] = in_type_(precision_type_(in[i]).sin());
39
+ }
40
+ }
41
+
42
+ /**
43
+ * @brief Array cosine: outᵢ = cos(inᵢ)
44
+ * @param[in] in Input array
45
+ * @param[in] n Number of elements
46
+ * @param[out] out Output array
47
+ *
48
+ * @tparam in_type_ Element type (f32_t, f64_t, f16_t)
49
+ * @tparam precision_type_ Precision type for scalar fallback, defaults to `in_type_`
50
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
51
+ */
52
+ template <numeric_dtype in_type_, numeric_dtype precision_type_ = in_type_, allow_simd_t allow_simd_ = prefer_simd_k>
53
+ void cos(in_type_ const *in, std::size_t n, in_type_ *out) noexcept {
54
+ constexpr bool simd = allow_simd_ == prefer_simd_k && std::is_same_v<in_type_, precision_type_>;
55
+
56
+ if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_each_cos_f64(&in->raw_, n, &out->raw_);
57
+ else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_each_cos_f32(&in->raw_, n, &out->raw_);
58
+ else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_each_cos_f16(&in->raw_, n, &out->raw_);
59
+ // Scalar fallback
60
+ else {
61
+ for (std::size_t i = 0; i < n; i++) out[i] = in_type_(precision_type_(in[i]).cos());
62
+ }
63
+ }
64
+
65
+ /**
66
+ * @brief Array arctangent: outᵢ = arctan(inᵢ)
67
+ * @param[in] in Input array
68
+ * @param[in] n Number of elements
69
+ * @param[out] out Output array
70
+ *
71
+ * @tparam in_type_ Element type (f32_t, f64_t, f16_t)
72
+ * @tparam precision_type_ Precision type for scalar fallback, defaults to `in_type_`
73
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
74
+ */
75
+ template <numeric_dtype in_type_, numeric_dtype precision_type_ = in_type_, allow_simd_t allow_simd_ = prefer_simd_k>
76
+ void atan(in_type_ const *in, std::size_t n, in_type_ *out) noexcept {
77
+ constexpr bool simd = allow_simd_ == prefer_simd_k && std::is_same_v<in_type_, precision_type_>;
78
+
79
+ if constexpr (std::is_same_v<in_type_, f64_t> && simd) nk_each_atan_f64(&in->raw_, n, &out->raw_);
80
+ else if constexpr (std::is_same_v<in_type_, f32_t> && simd) nk_each_atan_f32(&in->raw_, n, &out->raw_);
81
+ else if constexpr (std::is_same_v<in_type_, f16_t> && simd) nk_each_atan_f16(&in->raw_, n, &out->raw_);
82
+ // Scalar fallback
83
+ else {
84
+ for (std::size_t i = 0; i < n; i++) out[i] = in_type_(precision_type_(in[i]).atan());
85
+ }
86
+ }
87
+
88
+ } // namespace ashvardanian::numkong
89
+
90
+ #include "numkong/tensor.hpp"
91
+
92
+ namespace ashvardanian::numkong {
93
+
94
+ #pragma region - Tensor Trigonometric
95
+
96
+ /** @brief Elementwise sin into pre-allocated output. */
97
+ template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
98
+ bool sin(tensor_view<value_type_, max_rank_> input, tensor_span<value_type_, max_rank_> output) noexcept {
99
+ return elementwise_into_<value_type_, max_rank_>(
100
+ input, output, [](tensor_view<value_type_, max_rank_> in, tensor_span<value_type_, max_rank_> out) {
101
+ numkong::sin<value_type_>(in.data(), in.extent(0), out.data());
102
+ });
103
+ }
104
+
105
+ /** @brief Allocating sin. */
106
+ template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
107
+ typename allocator_type_ = aligned_allocator<value_type_>>
108
+ tensor<value_type_, allocator_type_, max_rank_> try_sin(tensor_view<value_type_, max_rank_> input) noexcept {
109
+ using out_tensor_t = tensor<value_type_, allocator_type_, max_rank_>;
110
+ if (input.empty()) return out_tensor_t {};
111
+ auto &input_shape = input.shape();
112
+ auto result = out_tensor_t::try_empty(input_shape.extents, input_shape.rank);
113
+ if (result.empty()) return result;
114
+ if (!sin<value_type_, max_rank_>(input, result.span())) return out_tensor_t {};
115
+ return result;
116
+ }
117
+
118
+ /** @brief Elementwise cos into pre-allocated output. */
119
+ template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
120
+ bool cos(tensor_view<value_type_, max_rank_> input, tensor_span<value_type_, max_rank_> output) noexcept {
121
+ return elementwise_into_<value_type_, max_rank_>(
122
+ input, output, [](tensor_view<value_type_, max_rank_> in, tensor_span<value_type_, max_rank_> out) {
123
+ numkong::cos<value_type_>(in.data(), in.extent(0), out.data());
124
+ });
125
+ }
126
+
127
+ /** @brief Allocating cos. */
128
+ template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
129
+ typename allocator_type_ = aligned_allocator<value_type_>>
130
+ tensor<value_type_, allocator_type_, max_rank_> try_cos(tensor_view<value_type_, max_rank_> input) noexcept {
131
+ using out_tensor_t = tensor<value_type_, allocator_type_, max_rank_>;
132
+ if (input.empty()) return out_tensor_t {};
133
+ auto &input_shape = input.shape();
134
+ auto result = out_tensor_t::try_empty(input_shape.extents, input_shape.rank);
135
+ if (result.empty()) return result;
136
+ if (!cos<value_type_, max_rank_>(input, result.span())) return out_tensor_t {};
137
+ return result;
138
+ }
139
+
140
+ /** @brief Elementwise atan into pre-allocated output. */
141
+ template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
142
+ bool atan(tensor_view<value_type_, max_rank_> input, tensor_span<value_type_, max_rank_> output) noexcept {
143
+ return elementwise_into_<value_type_, max_rank_>(
144
+ input, output, [](tensor_view<value_type_, max_rank_> in, tensor_span<value_type_, max_rank_> out) {
145
+ numkong::atan<value_type_>(in.data(), in.extent(0), out.data());
146
+ });
147
+ }
148
+
149
+ /** @brief Allocating atan. */
150
+ template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
151
+ typename allocator_type_ = aligned_allocator<value_type_>>
152
+ tensor<value_type_, allocator_type_, max_rank_> try_atan(tensor_view<value_type_, max_rank_> input) noexcept {
153
+ using out_tensor_t = tensor<value_type_, allocator_type_, max_rank_>;
154
+ if (input.empty()) return out_tensor_t {};
155
+ auto &input_shape = input.shape();
156
+ auto result = out_tensor_t::try_empty(input_shape.extents, input_shape.rank);
157
+ if (result.empty()) return result;
158
+ if (!atan<value_type_, max_rank_>(input, result.span())) return out_tensor_t {};
159
+ return result;
160
+ }
161
+
162
+ #pragma endregion - Tensor Trigonometric
163
+
164
+ } // namespace ashvardanian::numkong
165
+
166
+ #endif // NK_TRIGONOMETRY_HPP