numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,721 @@
1
+ /**
2
+ * @brief SIMD-accelerated Trigonometric Functions for Skylake.
3
+ * @file include/numkong/trigonometry/skylake.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/trigonometry.h
8
+ * @see https://sleef.org
9
+ *
10
+ * @section skylake_trig_instructions Key AVX-512 Trigonometry Instructions
11
+ *
12
+ * Intrinsic Instruction Latency Throughput Ports
13
+ * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p05
14
+ * _mm512_mul_ps VMULPS (ZMM, ZMM, ZMM) 4cy 0.5/cy p05
15
+ * _mm512_and_ps VANDPS (ZMM, ZMM, ZMM) 1cy 0.33/cy p015
16
+ * _mm512_cmp_ps_mask VCMPPS (K, ZMM, ZMM, I8) 3cy 1/cy p01
17
+ * _mm512_roundscale_ps VRNDSCALEPS (ZMM, ZMM, I8) 8cy 0.5/cy p01
18
+ *
19
+ * Trigonometric functions use polynomial approximations evaluated via Horner's method with FMA chains.
20
+ * AVX-512 mask registers enable branchless range reduction and sign handling without blend overhead.
21
+ * Skylake-X's dual FMA units achieve 0.5cy throughput, processing 32 f32 sin/cos values per 8 cycles.
22
+ */
23
+ #ifndef NK_TRIGONOMETRY_SKYLAKE_H
24
+ #define NK_TRIGONOMETRY_SKYLAKE_H
25
+
26
+ #if NK_TARGET_X86_
27
+ #if NK_TARGET_SKYLAKE
28
+
29
+ #include "numkong/types.h"
30
+
31
+ #if defined(__cplusplus)
32
+ extern "C" {
33
+ #endif
34
+
35
+ #if defined(__clang__)
36
+ #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
37
+ apply_to = function)
38
+ #elif defined(__GNUC__)
39
+ #pragma GCC push_options
40
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
41
+ #endif
42
+
43
+ NK_INTERNAL __m512 nk_sin_f32x16_skylake_(__m512 const angles_radians) {
44
+ // Cody-Waite constants for argument reduction
45
+ __m512 const pi_hi_f32x16 = _mm512_set1_ps(3.1415927f);
46
+ __m512 const pi_lo_f32x16 = _mm512_set1_ps(-8.742278e-8f);
47
+ __m512 const pi_reciprocal = _mm512_set1_ps(0.31830988618379067154f); // 1/π
48
+ // Degree-9 minimax coefficients
49
+ __m512 const coeff_9 = _mm512_set1_ps(+2.7557319224e-6f);
50
+ __m512 const coeff_7 = _mm512_set1_ps(-1.9841269841e-4f);
51
+ __m512 const coeff_5 = _mm512_set1_ps(+8.3333293855e-3f);
52
+ __m512 const coeff_3 = _mm512_set1_ps(-1.6666666641e-1f);
53
+
54
+ // Compute (multiples_of_pi) = round(angle / π)
55
+ __m512 quotients = _mm512_mul_ps(angles_radians, pi_reciprocal);
56
+ __m512 rounded_quotients = _mm512_roundscale_ps(quotients, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
57
+ // Use explicit rounding to match roundscale (MXCSR-independent)
58
+ __m512i multiples_of_pi = _mm512_cvt_roundps_epi32(rounded_quotients,
59
+ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
60
+
61
+ // Cody-Waite range reduction
62
+ __m512 angles = _mm512_fnmadd_ps(rounded_quotients, pi_hi_f32x16, angles_radians);
63
+ angles = _mm512_fnmadd_ps(rounded_quotients, pi_lo_f32x16, angles);
64
+ __m512 const angles_squared = _mm512_mul_ps(angles, angles);
65
+ __m512 const angles_cubed = _mm512_mul_ps(angles, angles_squared);
66
+
67
+ // Degree-9 polynomial via Horner's method
68
+ __m512 polynomials = coeff_9;
69
+ polynomials = _mm512_fmadd_ps(polynomials, angles_squared, coeff_7);
70
+ polynomials = _mm512_fmadd_ps(polynomials, angles_squared, coeff_5);
71
+ polynomials = _mm512_fmadd_ps(polynomials, angles_squared, coeff_3);
72
+
73
+ // If multiples_of_pi is odd, flip the sign of the results
74
+ __mmask16 odd_mask = _mm512_test_epi32_mask(multiples_of_pi, _mm512_set1_epi32(1));
75
+ __m512 results = _mm512_fmadd_ps(angles_cubed, polynomials, angles);
76
+ results = _mm512_mask_sub_ps(results, odd_mask, _mm512_setzero_ps(), results);
77
+ return results;
78
+ }
79
+
80
+ NK_INTERNAL __m512 nk_cos_f32x16_skylake_(__m512 const angles_radians) {
81
+ // Cody-Waite constants for argument reduction
82
+ __m512 const pi_hi_f32x16 = _mm512_set1_ps(3.1415927f);
83
+ __m512 const pi_lo_f32x16 = _mm512_set1_ps(-8.742278e-8f);
84
+ __m512 const pi_half = _mm512_set1_ps(1.57079632679489661923f); // π/2
85
+ __m512 const pi_reciprocal = _mm512_set1_ps(0.31830988618379067154f); // 1/π
86
+ // Degree-9 minimax coefficients
87
+ __m512 const coeff_9 = _mm512_set1_ps(+2.7557319224e-6f);
88
+ __m512 const coeff_7 = _mm512_set1_ps(-1.9841269841e-4f);
89
+ __m512 const coeff_5 = _mm512_set1_ps(+8.3333293855e-3f);
90
+ __m512 const coeff_3 = _mm512_set1_ps(-1.6666666641e-1f);
91
+
92
+ // Compute (multiples_of_pi) = round((angle / π) - 0.5)
93
+ __m512 quotients = _mm512_fmsub_ps(angles_radians, pi_reciprocal, _mm512_set1_ps(0.5f));
94
+ __m512 rounded_quotients = _mm512_roundscale_ps(quotients, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
95
+ // Use explicit rounding to match roundscale (MXCSR-independent)
96
+ __m512i multiples_of_pi = _mm512_cvt_roundps_epi32(rounded_quotients,
97
+ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
98
+
99
+ // Cody-Waite range reduction: angle = angle_radians - (multiples * pi + pi/2)
100
+ __m512 const offset = _mm512_fmadd_ps(rounded_quotients, pi_hi_f32x16, pi_half);
101
+ __m512 angles = _mm512_sub_ps(angles_radians, offset);
102
+ angles = _mm512_fnmadd_ps(rounded_quotients, pi_lo_f32x16, angles);
103
+ __m512 const angles_squared = _mm512_mul_ps(angles, angles);
104
+ __m512 const angles_cubed = _mm512_mul_ps(angles, angles_squared);
105
+
106
+ // Degree-9 polynomial via Horner's method
107
+ __m512 polynomials = coeff_9;
108
+ polynomials = _mm512_fmadd_ps(polynomials, angles_squared, coeff_7);
109
+ polynomials = _mm512_fmadd_ps(polynomials, angles_squared, coeff_5);
110
+ polynomials = _mm512_fmadd_ps(polynomials, angles_squared, coeff_3);
111
+ __m512 results = _mm512_fmadd_ps(angles_cubed, polynomials, angles);
112
+
113
+ // If multiples_of_pi is even, flip the sign of the results
114
+ __mmask16 even_mask = _mm512_testn_epi32_mask(multiples_of_pi, _mm512_set1_epi32(1));
115
+ results = _mm512_mask_sub_ps(results, even_mask, _mm512_setzero_ps(), results);
116
+ return results;
117
+ }
118
+
119
+ NK_INTERNAL __m512 nk_atan_f32x16_skylake_(__m512 const inputs) {
120
+ // Polynomial coefficients
121
+ __m512 const coeff_8 = _mm512_set1_ps(-0.333331018686294555664062f);
122
+ __m512 const coeff_7 = _mm512_set1_ps(+0.199926957488059997558594f);
123
+ __m512 const coeff_6 = _mm512_set1_ps(-0.142027363181114196777344f);
124
+ __m512 const coeff_5 = _mm512_set1_ps(+0.106347933411598205566406f);
125
+ __m512 const coeff_4 = _mm512_set1_ps(-0.0748900920152664184570312f);
126
+ __m512 const coeff_3 = _mm512_set1_ps(+0.0425049886107444763183594f);
127
+ __m512 const coeff_2 = _mm512_set1_ps(-0.0159569028764963150024414f);
128
+ __m512 const coeff_1 = _mm512_set1_ps(+0.00282363896258175373077393f);
129
+
130
+ // Adjust for quadrant
131
+ __m512 values = inputs;
132
+ __mmask16 const negative_mask = _mm512_fpclass_ps_mask(values, 0x40);
133
+ values = _mm512_abs_ps(values);
134
+ __mmask16 const reciprocal_mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(1.0f), _CMP_GT_OS);
135
+ values = _mm512_mask_div_ps(values, reciprocal_mask, _mm512_set1_ps(1.0f), values);
136
+
137
+ // Argument reduction
138
+ __m512 const values_squared = _mm512_mul_ps(values, values);
139
+ __m512 const values_cubed = _mm512_mul_ps(values, values_squared);
140
+
141
+ // Polynomial evaluation
142
+ __m512 polynomials = coeff_1;
143
+ polynomials = _mm512_fmadd_ps(polynomials, values_squared, coeff_2);
144
+ polynomials = _mm512_fmadd_ps(polynomials, values_squared, coeff_3);
145
+ polynomials = _mm512_fmadd_ps(polynomials, values_squared, coeff_4);
146
+ polynomials = _mm512_fmadd_ps(polynomials, values_squared, coeff_5);
147
+ polynomials = _mm512_fmadd_ps(polynomials, values_squared, coeff_6);
148
+ polynomials = _mm512_fmadd_ps(polynomials, values_squared, coeff_7);
149
+ polynomials = _mm512_fmadd_ps(polynomials, values_squared, coeff_8);
150
+
151
+ // Adjust result for quadrants
152
+ __m512 result = _mm512_fmadd_ps(values_cubed, polynomials, values);
153
+ result = _mm512_mask_sub_ps(result, reciprocal_mask, _mm512_set1_ps(1.5707963267948966f), result);
154
+ result = _mm512_mask_sub_ps(result, negative_mask, _mm512_setzero_ps(), result);
155
+ return result;
156
+ }
157
+
158
+ NK_INTERNAL __m512 nk_atan2_f32x16_skylake_(__m512 const ys_inputs, __m512 const xs_inputs) {
159
+ // Polynomial coefficients
160
+ __m512 const coeff_8 = _mm512_set1_ps(-0.333331018686294555664062f);
161
+ __m512 const coeff_7 = _mm512_set1_ps(+0.199926957488059997558594f);
162
+ __m512 const coeff_6 = _mm512_set1_ps(-0.142027363181114196777344f);
163
+ __m512 const coeff_5 = _mm512_set1_ps(+0.106347933411598205566406f);
164
+ __m512 const coeff_4 = _mm512_set1_ps(-0.0748900920152664184570312f);
165
+ __m512 const coeff_3 = _mm512_set1_ps(+0.0425049886107444763183594f);
166
+ __m512 const coeff_2 = _mm512_set1_ps(-0.0159569028764963150024414f);
167
+ __m512 const coeff_1 = _mm512_set1_ps(+0.00282363896258175373077393f);
168
+
169
+ // Quadrant adjustments normalizing to absolute values of x and y
170
+ __mmask16 const xs_negative_mask = _mm512_fpclass_ps_mask(xs_inputs, 0x40);
171
+ __m512 xs = _mm512_abs_ps(xs_inputs);
172
+ __m512 ys = _mm512_abs_ps(ys_inputs);
173
+ // Ensure proper fraction where the numerator is smaller than the denominator
174
+ __mmask16 const swap_mask = _mm512_cmp_ps_mask(ys, xs, _CMP_GT_OS);
175
+ __m512 temps = xs;
176
+ xs = _mm512_mask_blend_ps(swap_mask, xs, ys);
177
+ ys = _mm512_mask_sub_ps(ys, swap_mask, _mm512_setzero_ps(), temps);
178
+
179
+ // Compute ratio and ratio²
180
+ __m512 const ratio = _mm512_div_ps(ys, xs);
181
+ __m512 const ratio_squared = _mm512_mul_ps(ratio, ratio);
182
+ __m512 const ratio_cubed = _mm512_mul_ps(ratio, ratio_squared);
183
+
184
+ // Polynomial evaluation
185
+ __m512 polynomials = coeff_1;
186
+ polynomials = _mm512_fmadd_ps(polynomials, ratio_squared, coeff_2);
187
+ polynomials = _mm512_fmadd_ps(polynomials, ratio_squared, coeff_3);
188
+ polynomials = _mm512_fmadd_ps(polynomials, ratio_squared, coeff_4);
189
+ polynomials = _mm512_fmadd_ps(polynomials, ratio_squared, coeff_5);
190
+ polynomials = _mm512_fmadd_ps(polynomials, ratio_squared, coeff_6);
191
+ polynomials = _mm512_fmadd_ps(polynomials, ratio_squared, coeff_7);
192
+ polynomials = _mm512_fmadd_ps(polynomials, ratio_squared, coeff_8);
193
+
194
+ // Compute quadrant value: 0 for x>=0 && !swap, 1 for x>=0 && swap,
195
+ // -2 for x<0 && !swap, -1 for x<0 && swap
196
+ __m512 results = _mm512_fmadd_ps(ratio_cubed, polynomials, ratio);
197
+ __m512 quadrant = _mm512_setzero_ps();
198
+ __m512 neg_two = _mm512_set1_ps(-2.0f);
199
+ quadrant = _mm512_mask_blend_ps(xs_negative_mask, quadrant, neg_two);
200
+ __m512 one = _mm512_set1_ps(1.0f);
201
+ __m512 quadrant_incremented = _mm512_add_ps(quadrant, one);
202
+ quadrant = _mm512_mask_blend_ps(swap_mask, quadrant, quadrant_incremented);
203
+
204
+ // Adjust for quadrant: result += quadrant * π/2
205
+ __m512 pi_half = _mm512_set1_ps(1.5707963267948966f);
206
+ results = _mm512_fmadd_ps(quadrant, pi_half, results);
207
+
208
+ // Transfer sign from x (XOR with sign bit of x_input)
209
+ __m512 xs_sign_bits = _mm512_and_ps(xs_inputs, _mm512_set1_ps(-0.0f));
210
+ results = _mm512_xor_ps(results, xs_sign_bits);
211
+
212
+ // Transfer sign from y (XOR with sign bit of y_input)
213
+ __m512 ys_sign_bits = _mm512_and_ps(ys_inputs, _mm512_set1_ps(-0.0f));
214
+ results = _mm512_xor_ps(results, ys_sign_bits);
215
+
216
+ return results;
217
+ }
218
+
219
+ NK_PUBLIC void nk_each_sin_f32_skylake(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs) {
220
+ nk_size_t i = 0;
221
+ for (; i + 16 <= n; i += 16) {
222
+ __m512 angles = _mm512_loadu_ps(ins + i);
223
+ __m512 results = nk_sin_f32x16_skylake_(angles);
224
+ _mm512_storeu_ps(outs + i, results);
225
+ }
226
+ if (i < n) {
227
+ __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n - i);
228
+ __m512 angles = _mm512_maskz_loadu_ps(mask, ins + i);
229
+ __m512 results = nk_sin_f32x16_skylake_(angles);
230
+ _mm512_mask_storeu_ps(outs + i, mask, results);
231
+ }
232
+ }
233
+ NK_PUBLIC void nk_each_cos_f32_skylake(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs) {
234
+ nk_size_t i = 0;
235
+ for (; i + 16 <= n; i += 16) {
236
+ __m512 angles = _mm512_loadu_ps(ins + i);
237
+ __m512 results = nk_cos_f32x16_skylake_(angles);
238
+ _mm512_storeu_ps(outs + i, results);
239
+ }
240
+ if (i < n) {
241
+ __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n - i);
242
+ __m512 angles = _mm512_maskz_loadu_ps(mask, ins + i);
243
+ __m512 results = nk_cos_f32x16_skylake_(angles);
244
+ _mm512_mask_storeu_ps(outs + i, mask, results);
245
+ }
246
+ }
247
+ NK_PUBLIC void nk_each_atan_f32_skylake(nk_f32_t const *ins, nk_size_t n, nk_f32_t *outs) {
248
+ nk_size_t i = 0;
249
+ for (; i + 16 <= n; i += 16) {
250
+ __m512 angles = _mm512_loadu_ps(ins + i);
251
+ __m512 results = nk_atan_f32x16_skylake_(angles);
252
+ _mm512_storeu_ps(outs + i, results);
253
+ }
254
+ if (i < n) {
255
+ __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n - i);
256
+ __m512 angles = _mm512_maskz_loadu_ps(mask, ins + i);
257
+ __m512 results = nk_atan_f32x16_skylake_(angles);
258
+ _mm512_mask_storeu_ps(outs + i, mask, results);
259
+ }
260
+ }
261
+
262
+ NK_INTERNAL __m512d nk_sin_f64x8_skylake_(__m512d const angles_radians) {
263
+ // Constants for argument reduction
264
+ __m512d const pi_high = _mm512_set1_pd(3.141592653589793116); // High-digits part of π
265
+ __m512d const pi_low = _mm512_set1_pd(1.2246467991473532072e-16); // Low-digits part of π
266
+ __m512d const pi_reciprocal = _mm512_set1_pd(0.31830988618379067154); // 1/π
267
+
268
+ // Polynomial coefficients for sine/cosine approximation (minimax polynomial)
269
+ __m512d const coeff_0 = _mm512_set1_pd(+0.00833333333333332974823815);
270
+ __m512d const coeff_1 = _mm512_set1_pd(-0.000198412698412696162806809);
271
+ __m512d const coeff_2 = _mm512_set1_pd(+2.75573192239198747630416e-06);
272
+ __m512d const coeff_3 = _mm512_set1_pd(-2.50521083763502045810755e-08);
273
+ __m512d const coeff_4 = _mm512_set1_pd(+1.60590430605664501629054e-10);
274
+ __m512d const coeff_5 = _mm512_set1_pd(-7.64712219118158833288484e-13);
275
+ __m512d const coeff_6 = _mm512_set1_pd(+2.81009972710863200091251e-15);
276
+ __m512d const coeff_7 = _mm512_set1_pd(-7.97255955009037868891952e-18);
277
+ __m512d const coeff_8 = _mm512_set1_pd(-0.166666666666666657414808);
278
+
279
+ // Compute (rounded_quotients) = round(angle / π)
280
+ __m512d const quotients = _mm512_mul_pd(angles_radians, pi_reciprocal);
281
+ __m512d const rounded_quotients = _mm512_roundscale_pd(quotients, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
282
+
283
+ // Reduce the angle to: angle - (rounded_quotients * π_high + rounded_quotients * π_low)
284
+ __m512d angles = angles_radians;
285
+ angles = _mm512_fnmadd_pd(rounded_quotients, pi_high, angles);
286
+ angles = _mm512_fnmadd_pd(rounded_quotients, pi_low, angles);
287
+
288
+ // If rounded_quotients is odd (bit 0 set), negate the angle
289
+ // Use explicit rounding to match roundscale (MXCSR-independent)
290
+ __mmask8 const sign_flip_mask = _mm256_test_epi32_mask(
291
+ _mm512_cvt_roundpd_epi32(rounded_quotients, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC),
292
+ _mm256_set1_epi32(1));
293
+ angles = _mm512_mask_sub_pd(angles, sign_flip_mask, _mm512_setzero_pd(), angles);
294
+
295
+ __m512d const angles_squared = _mm512_mul_pd(angles, angles);
296
+ __m512d const angles_cubed = _mm512_mul_pd(angles, angles_squared);
297
+ __m512d const angles_quadratic = _mm512_mul_pd(angles_squared, angles_squared);
298
+ __m512d const angles_octic = _mm512_mul_pd(angles_quadratic, angles_quadratic);
299
+
300
+ // Compute higher-degree polynomial terms
301
+ __m512d const poly_67 = _mm512_fmadd_pd(angles_squared, coeff_7, coeff_6);
302
+ __m512d const poly_45 = _mm512_fmadd_pd(angles_squared, coeff_5, coeff_4);
303
+ __m512d const poly_4567 = _mm512_fmadd_pd(angles_quadratic, poly_67, poly_45);
304
+
305
+ // Compute lower-degree polynomial terms
306
+ __m512d const poly_23 = _mm512_fmadd_pd(angles_squared, coeff_3, coeff_2);
307
+ __m512d const poly_01 = _mm512_fmadd_pd(angles_squared, coeff_1, coeff_0);
308
+ __m512d const poly_0123 = _mm512_fmadd_pd(angles_quadratic, poly_23, poly_01);
309
+
310
+ // Combine polynomial terms
311
+ __m512d results = _mm512_fmadd_pd(angles_octic, poly_4567, poly_0123);
312
+ results = _mm512_fmadd_pd(results, angles_squared, coeff_8);
313
+ results = _mm512_fmadd_pd(results, angles_cubed, angles);
314
+
315
+ // Handle the special case of negative zero input
316
+ __mmask8 const non_zero_mask = _mm512_cmpneq_pd_mask(angles_radians, _mm512_setzero_pd());
317
+ results = _mm512_maskz_mov_pd(non_zero_mask, results);
318
+ return results;
319
+ }
320
+
321
+ NK_INTERNAL __m512d nk_cos_f64x8_skylake_(__m512d const angles_radians) {
322
+ // Constants for argument reduction
323
+ __m512d const pi_high_half = _mm512_set1_pd(3.141592653589793116 * 0.5); // High-digits part of π
324
+ __m512d const pi_low_half = _mm512_set1_pd(1.2246467991473532072e-16 * 0.5); // Low-digits part of π
325
+ __m512d const pi_reciprocal = _mm512_set1_pd(0.31830988618379067154); // 1/π
326
+
327
+ // Polynomial coefficients for sine/cosine approximation (minimax polynomial)
328
+ __m512d const coeff_0 = _mm512_set1_pd(+0.00833333333333332974823815);
329
+ __m512d const coeff_1 = _mm512_set1_pd(-0.000198412698412696162806809);
330
+ __m512d const coeff_2 = _mm512_set1_pd(+2.75573192239198747630416e-06);
331
+ __m512d const coeff_3 = _mm512_set1_pd(-2.50521083763502045810755e-08);
332
+ __m512d const coeff_4 = _mm512_set1_pd(+1.60590430605664501629054e-10);
333
+ __m512d const coeff_5 = _mm512_set1_pd(-7.64712219118158833288484e-13);
334
+ __m512d const coeff_6 = _mm512_set1_pd(+2.81009972710863200091251e-15);
335
+ __m512d const coeff_7 = _mm512_set1_pd(-7.97255955009037868891952e-18);
336
+ __m512d const coeff_8 = _mm512_set1_pd(-0.166666666666666657414808);
337
+
338
+ // Compute (rounded_quotients) = 2 * round(angle / π - 0.5) + 1
339
+ // Use fmsub: a*b - c = angles * (1/π) - 0.5
340
+ __m512d const quotients = _mm512_fmsub_pd(angles_radians, pi_reciprocal, _mm512_set1_pd(0.5));
341
+ __m512d const rounded_quotients = _mm512_fmadd_pd( //
342
+ _mm512_set1_pd(2), //
343
+ _mm512_roundscale_pd(quotients, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC), //
344
+ _mm512_set1_pd(1));
345
+
346
+ // Reduce the angle to: angle - (rounded_quotients * π_high + rounded_quotients * π_low)
347
+ __m512d angles = angles_radians;
348
+ angles = _mm512_fnmadd_pd(rounded_quotients, pi_high_half, angles);
349
+ angles = _mm512_fnmadd_pd(rounded_quotients, pi_low_half, angles);
350
+ // Use explicit rounding to match roundscale (MXCSR-independent)
351
+ __mmask8 const sign_flip_mask = _mm256_testn_epi32_mask(
352
+ _mm512_cvt_roundpd_epi32(rounded_quotients, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC),
353
+ _mm256_set1_epi32(2));
354
+ angles = _mm512_mask_sub_pd(angles, sign_flip_mask, _mm512_setzero_pd(), angles);
355
+ __m512d const angles_squared = _mm512_mul_pd(angles, angles);
356
+ __m512d const angles_cubed = _mm512_mul_pd(angles, angles_squared);
357
+ __m512d const angles_quadratic = _mm512_mul_pd(angles_squared, angles_squared);
358
+ __m512d const angles_octic = _mm512_mul_pd(angles_quadratic, angles_quadratic);
359
+
360
+ // Compute higher-degree polynomial terms
361
+ __m512d const poly_67 = _mm512_fmadd_pd(angles_squared, coeff_7, coeff_6);
362
+ __m512d const poly_45 = _mm512_fmadd_pd(angles_squared, coeff_5, coeff_4);
363
+ __m512d const poly_4567 = _mm512_fmadd_pd(angles_quadratic, poly_67, poly_45);
364
+
365
+ // Compute lower-degree polynomial terms
366
+ __m512d const poly_23 = _mm512_fmadd_pd(angles_squared, coeff_3, coeff_2);
367
+ __m512d const poly_01 = _mm512_fmadd_pd(angles_squared, coeff_1, coeff_0);
368
+ __m512d const poly_0123 = _mm512_fmadd_pd(angles_quadratic, poly_23, poly_01);
369
+
370
+ // Combine polynomial terms
371
+ __m512d results = _mm512_fmadd_pd(angles_octic, poly_4567, poly_0123);
372
+ results = _mm512_fmadd_pd(results, angles_squared, coeff_8);
373
+ results = _mm512_fmadd_pd(results, angles_cubed, angles);
374
+ return results;
375
+ }
376
+
377
+ NK_INTERNAL __m512d nk_atan_f64x8_skylake_(__m512d const inputs) {
378
+ // Polynomial coefficients for atan approximation
379
+ __m512d const coeff_19 = _mm512_set1_pd(-1.88796008463073496563746e-05);
380
+ __m512d const coeff_18 = _mm512_set1_pd(+0.000209850076645816976906797);
381
+ __m512d const coeff_17 = _mm512_set1_pd(-0.00110611831486672482563471);
382
+ __m512d const coeff_16 = _mm512_set1_pd(+0.00370026744188713119232403);
383
+ __m512d const coeff_15 = _mm512_set1_pd(-0.00889896195887655491740809);
384
+ __m512d const coeff_14 = _mm512_set1_pd(+0.016599329773529201970117);
385
+ __m512d const coeff_13 = _mm512_set1_pd(-0.0254517624932312641616861);
386
+ __m512d const coeff_12 = _mm512_set1_pd(+0.0337852580001353069993897);
387
+ __m512d const coeff_11 = _mm512_set1_pd(-0.0407629191276836500001934);
388
+ __m512d const coeff_10 = _mm512_set1_pd(+0.0466667150077840625632675);
389
+ __m512d const coeff_9 = _mm512_set1_pd(-0.0523674852303482457616113);
390
+ __m512d const coeff_8 = _mm512_set1_pd(+0.0587666392926673580854313);
391
+ __m512d const coeff_7 = _mm512_set1_pd(-0.0666573579361080525984562);
392
+ __m512d const coeff_6 = _mm512_set1_pd(+0.0769219538311769618355029);
393
+ __m512d const coeff_5 = _mm512_set1_pd(-0.090908995008245008229153);
394
+ __m512d const coeff_4 = _mm512_set1_pd(+0.111111105648261418443745);
395
+ __m512d const coeff_3 = _mm512_set1_pd(-0.14285714266771329383765);
396
+ __m512d const coeff_2 = _mm512_set1_pd(+0.199999999996591265594148);
397
+ __m512d const coeff_1 = _mm512_set1_pd(-0.333333333333311110369124);
398
+
399
+ // Quadrant adjustments
400
+ __mmask8 negative_mask = _mm512_cmp_pd_mask(inputs, _mm512_setzero_pd(), _CMP_LT_OS);
401
+ __m512d values = _mm512_abs_pd(inputs);
402
+ __mmask8 reciprocal_mask = _mm512_cmp_pd_mask(values, _mm512_set1_pd(1.0), _CMP_GT_OS);
403
+ values = _mm512_mask_div_pd(values, reciprocal_mask, _mm512_set1_pd(1.0), values);
404
+ __m512d const values_squared = _mm512_mul_pd(values, values);
405
+ __m512d const values_cubed = _mm512_mul_pd(values, values_squared);
406
+
407
+ // Polynomial evaluation (argument reduction and approximation)
408
+ __m512d polynomials = coeff_19;
409
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_18);
410
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_17);
411
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_16);
412
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_15);
413
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_14);
414
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_13);
415
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_12);
416
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_11);
417
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_10);
418
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_9);
419
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_8);
420
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_7);
421
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_6);
422
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_5);
423
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_4);
424
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_3);
425
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_2);
426
+ polynomials = _mm512_fmadd_pd(polynomials, values_squared, coeff_1);
427
+
428
+ // Compute atan approximation
429
+ __m512d result = _mm512_fmadd_pd(values_cubed, polynomials, values);
430
+ result = _mm512_mask_sub_pd(result, reciprocal_mask, _mm512_set1_pd(1.5707963267948966), result);
431
+ result = _mm512_mask_sub_pd(result, negative_mask, _mm512_setzero_pd(), result);
432
+ return result;
433
+ }
434
+
435
+ /**
436
+ * @brief AVX-512 implementation of atan2(y, x) for 8 double-precision values.
437
+ * @see Based on the f32x16 version with appropriate precision constants.
438
+ */
439
+ NK_INTERNAL __m512d nk_atan2_f64x8_skylake_(__m512d const ys_inputs, __m512d const xs_inputs) {
440
+ // Polynomial coefficients for atan approximation (higher precision than f32)
441
+ __m512d const coeff_19 = _mm512_set1_pd(-1.88796008463073496563746e-05);
442
+ __m512d const coeff_18 = _mm512_set1_pd(+0.000209850076645816976906797);
443
+ __m512d const coeff_17 = _mm512_set1_pd(-0.00110611831486672482563471);
444
+ __m512d const coeff_16 = _mm512_set1_pd(+0.00370026744188713119232403);
445
+ __m512d const coeff_15 = _mm512_set1_pd(-0.00889896195887655491740809);
446
+ __m512d const coeff_14 = _mm512_set1_pd(+0.016599329773529201970117);
447
+ __m512d const coeff_13 = _mm512_set1_pd(-0.0254517624932312641616861);
448
+ __m512d const coeff_12 = _mm512_set1_pd(+0.0337852580001353069993897);
449
+ __m512d const coeff_11 = _mm512_set1_pd(-0.0407629191276836500001934);
450
+ __m512d const coeff_10 = _mm512_set1_pd(+0.0466667150077840625632675);
451
+ __m512d const coeff_9 = _mm512_set1_pd(-0.0523674852303482457616113);
452
+ __m512d const coeff_8 = _mm512_set1_pd(+0.0587666392926673580854313);
453
+ __m512d const coeff_7 = _mm512_set1_pd(-0.0666573579361080525984562);
454
+ __m512d const coeff_6 = _mm512_set1_pd(+0.0769219538311769618355029);
455
+ __m512d const coeff_5 = _mm512_set1_pd(-0.090908995008245008229153);
456
+ __m512d const coeff_4 = _mm512_set1_pd(+0.111111105648261418443745);
457
+ __m512d const coeff_3 = _mm512_set1_pd(-0.14285714266771329383765);
458
+ __m512d const coeff_2 = _mm512_set1_pd(+0.199999999996591265594148);
459
+ __m512d const coeff_1 = _mm512_set1_pd(-0.333333333333311110369124);
460
+
461
+ // Quadrant adjustments normalizing to absolute values of x and y
462
+ __mmask8 const xs_negative_mask = _mm512_cmp_pd_mask(xs_inputs, _mm512_setzero_pd(), _CMP_LT_OS);
463
+ __m512d xs = _mm512_abs_pd(xs_inputs);
464
+ __m512d ys = _mm512_abs_pd(ys_inputs);
465
+ // Ensure proper fraction where the numerator is smaller than the denominator
466
+ __mmask8 const swap_mask = _mm512_cmp_pd_mask(ys, xs, _CMP_GT_OS);
467
+ __m512d temps = xs;
468
+ xs = _mm512_mask_blend_pd(swap_mask, xs, ys);
469
+ ys = _mm512_mask_sub_pd(ys, swap_mask, _mm512_setzero_pd(), temps);
470
+
471
+ // Compute ratio and ratio²
472
+ __m512d const ratio = _mm512_div_pd(ys, xs);
473
+ __m512d const ratio_squared = _mm512_mul_pd(ratio, ratio);
474
+ __m512d const ratio_cubed = _mm512_mul_pd(ratio, ratio_squared);
475
+
476
+ // Polynomial evaluation
477
+ __m512d polynomials = coeff_19;
478
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_18);
479
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_17);
480
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_16);
481
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_15);
482
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_14);
483
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_13);
484
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_12);
485
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_11);
486
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_10);
487
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_9);
488
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_8);
489
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_7);
490
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_6);
491
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_5);
492
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_4);
493
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_3);
494
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_2);
495
+ polynomials = _mm512_fmadd_pd(polynomials, ratio_squared, coeff_1);
496
+
497
+ // Compute the result with quadrant adjustments
498
+ __m512d results = _mm512_fmadd_pd(ratio_cubed, polynomials, ratio);
499
+
500
+ // Compute quadrant value: 0 for x>=0 && !swap, 1 for x>=0 && swap,
501
+ // -2 for x<0 && !swap, -1 for x<0 && swap
502
+ __m512d quadrant = _mm512_setzero_pd();
503
+ quadrant = _mm512_mask_blend_pd(xs_negative_mask, quadrant, _mm512_set1_pd(-2.0));
504
+ __m512d quadrant_incremented = _mm512_add_pd(quadrant, _mm512_set1_pd(1.0));
505
+ quadrant = _mm512_mask_blend_pd(swap_mask, quadrant, quadrant_incremented);
506
+
507
+ // Adjust for quadrant: result += quadrant * π/2
508
+ results = _mm512_fmadd_pd(quadrant, _mm512_set1_pd(1.5707963267948966), results);
509
+
510
+ // Transfer sign from x (XOR with sign bit of x_input)
511
+ __m512d xs_sign = _mm512_and_pd(xs_inputs, _mm512_set1_pd(-0.0));
512
+ results = _mm512_xor_pd(results, xs_sign);
513
+
514
+ // Transfer sign from y (XOR with sign bit of y_input)
515
+ __m512d ys_sign = _mm512_and_pd(ys_inputs, _mm512_set1_pd(-0.0));
516
+ results = _mm512_xor_pd(results, ys_sign);
517
+
518
+ return results;
519
+ }
520
+
521
+ NK_PUBLIC void nk_each_sin_f64_skylake(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs) {
522
+ nk_size_t i = 0;
523
+ for (; i + 8 <= n; i += 8) {
524
+ __m512d angles = _mm512_loadu_pd(ins + i);
525
+ __m512d results = nk_sin_f64x8_skylake_(angles);
526
+ _mm512_storeu_pd(outs + i, results);
527
+ }
528
+ if (i < n) {
529
+ __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFF, n - i);
530
+ __m512d angles = _mm512_maskz_loadu_pd(mask, ins + i);
531
+ __m512d results = nk_sin_f64x8_skylake_(angles);
532
+ _mm512_mask_storeu_pd(outs + i, mask, results);
533
+ }
534
+ }
535
+ NK_PUBLIC void nk_each_cos_f64_skylake(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs) {
536
+ nk_size_t i = 0;
537
+ for (; i + 8 <= n; i += 8) {
538
+ __m512d angles = _mm512_loadu_pd(ins + i);
539
+ __m512d results = nk_cos_f64x8_skylake_(angles);
540
+ _mm512_storeu_pd(outs + i, results);
541
+ }
542
+ if (i < n) {
543
+ __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFF, n - i);
544
+ __m512d angles = _mm512_maskz_loadu_pd(mask, ins + i);
545
+ __m512d results = nk_cos_f64x8_skylake_(angles);
546
+ _mm512_mask_storeu_pd(outs + i, mask, results);
547
+ }
548
+ }
549
+ NK_PUBLIC void nk_each_atan_f64_skylake(nk_f64_t const *ins, nk_size_t n, nk_f64_t *outs) {
550
+ nk_size_t i = 0;
551
+ for (; i + 8 <= n; i += 8) {
552
+ __m512d angles = _mm512_loadu_pd(ins + i);
553
+ __m512d results = nk_atan_f64x8_skylake_(angles);
554
+ _mm512_storeu_pd(outs + i, results);
555
+ }
556
+ if (i < n) {
557
+ __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFF, n - i);
558
+ __m512d angles = _mm512_maskz_loadu_pd(mask, ins + i);
559
+ __m512d results = nk_atan_f64x8_skylake_(angles);
560
+ _mm512_mask_storeu_pd(outs + i, mask, results);
561
+ }
562
+ }
563
+
564
+ /**
565
+ * @brief Sine approximation for 16 f16 values via f32 upcasting.
566
+ *
567
+ * Degree-5 polynomial with Cody-Waite range reduction in f32.
568
+ * Takes __m256i (f16 data), returns __m256i (f16 result).
569
+ */
570
+ NK_INTERNAL __m256i nk_sin_f16x16_skylake_(__m256i angles_f16x16) {
571
+ __m512 angles_f32x16 = _mm512_cvtph_ps(angles_f16x16);
572
+ // Cody-Waite range reduction constants
573
+ __m512 pi_hi_f32x16 = _mm512_set1_ps(3.1415927f);
574
+ __m512 pi_lo_f32x16 = _mm512_set1_ps(-8.742278e-8f);
575
+ __m512 pi_recip_f32x16 = _mm512_set1_ps(0.31830988618f);
576
+ __m512 c3_f32x16 = _mm512_set1_ps(-1.6666666641e-1f);
577
+ __m512 c5_f32x16 = _mm512_set1_ps(8.3333293855e-3f);
578
+
579
+ __m512 quotient_f32x16 = _mm512_mul_ps(angles_f32x16, pi_recip_f32x16);
580
+ __m512 rounded_f32x16 = _mm512_roundscale_ps(quotient_f32x16, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
581
+ // Use explicit rounding to match roundscale (MXCSR-independent)
582
+ __m512i multiple_i32x16 = _mm512_cvt_roundps_epi32(rounded_f32x16, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
583
+
584
+ angles_f32x16 = _mm512_fnmadd_ps(rounded_f32x16, pi_hi_f32x16, angles_f32x16);
585
+ angles_f32x16 = _mm512_fnmadd_ps(rounded_f32x16, pi_lo_f32x16, angles_f32x16);
586
+
587
+ __m512 x2_f32x16 = _mm512_mul_ps(angles_f32x16, angles_f32x16);
588
+ __m512 poly_f32x16 = _mm512_fmadd_ps(c5_f32x16, x2_f32x16, c3_f32x16);
589
+ poly_f32x16 = _mm512_mul_ps(poly_f32x16, x2_f32x16);
590
+ __m512 result_f32x16 = _mm512_fmadd_ps(poly_f32x16, angles_f32x16, angles_f32x16);
591
+
592
+ __mmask16 odd_mask = _mm512_test_epi32_mask(multiple_i32x16, _mm512_set1_epi32(1));
593
+ result_f32x16 = _mm512_mask_sub_ps(result_f32x16, odd_mask, _mm512_setzero_ps(), result_f32x16);
594
+ return _mm512_cvtps_ph(result_f32x16, _MM_FROUND_TO_NEAREST_INT);
595
+ }
596
+
597
+ /**
598
+ * @brief Cosine approximation for 16 f16 values via f32 upcasting.
599
+ *
600
+ * Uses cos(x) = sin(x + pi/2) with Cody-Waite range reduction in f32.
601
+ */
602
+ NK_INTERNAL __m256i nk_cos_f16x16_skylake_(__m256i angles_f16x16) {
603
+ __m512 angles_f32x16 = _mm512_cvtph_ps(angles_f16x16);
604
+ __m512 pi_hi_f32x16 = _mm512_set1_ps(3.1415927f);
605
+ __m512 pi_lo_f32x16 = _mm512_set1_ps(-8.742278e-8f);
606
+ __m512 pi_half_f32x16 = _mm512_set1_ps(1.5707963268f);
607
+ __m512 pi_recip_f32x16 = _mm512_set1_ps(0.31830988618f);
608
+ __m512 half_f32x16 = _mm512_set1_ps(0.5f);
609
+ __m512 c3_f32x16 = _mm512_set1_ps(-1.6666666641e-1f);
610
+ __m512 c5_f32x16 = _mm512_set1_ps(8.3333293855e-3f);
611
+
612
+ __m512 quotient_f32x16 = _mm512_fmsub_ps(angles_f32x16, pi_recip_f32x16, half_f32x16);
613
+ __m512 rounded_f32x16 = _mm512_roundscale_ps(quotient_f32x16, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
614
+ // Use explicit rounding to match roundscale (MXCSR-independent)
615
+ __m512i multiple_i32x16 = _mm512_cvt_roundps_epi32(rounded_f32x16, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
616
+
617
+ __m512 shift_f32x16 = _mm512_fmadd_ps(rounded_f32x16, pi_hi_f32x16, pi_half_f32x16);
618
+ angles_f32x16 = _mm512_sub_ps(angles_f32x16, shift_f32x16);
619
+ angles_f32x16 = _mm512_fnmadd_ps(rounded_f32x16, pi_lo_f32x16, angles_f32x16);
620
+
621
+ __m512 x2_f32x16 = _mm512_mul_ps(angles_f32x16, angles_f32x16);
622
+ __m512 poly_f32x16 = _mm512_fmadd_ps(c5_f32x16, x2_f32x16, c3_f32x16);
623
+ poly_f32x16 = _mm512_mul_ps(poly_f32x16, x2_f32x16);
624
+ __m512 result_f32x16 = _mm512_fmadd_ps(poly_f32x16, angles_f32x16, angles_f32x16);
625
+
626
+ __mmask16 even_mask = _mm512_testn_epi32_mask(multiple_i32x16, _mm512_set1_epi32(1));
627
+ result_f32x16 = _mm512_mask_sub_ps(result_f32x16, even_mask, _mm512_setzero_ps(), result_f32x16);
628
+ return _mm512_cvtps_ph(result_f32x16, _MM_FROUND_TO_NEAREST_INT);
629
+ }
630
+
631
+ /**
632
+ * @brief Arctangent approximation for 16 f16 values via f32 upcasting.
633
+ *
634
+ * Degree-9 polynomial in f32 with quadrant adjustments.
635
+ */
636
+ NK_INTERNAL __m256i nk_atan_f16x16_skylake_(__m256i values_f16x16) {
637
+ __m512 values_f32x16 = _mm512_cvtph_ps(values_f16x16);
638
+ __m512 c3_f32x16 = _mm512_set1_ps(-0.3333333333f);
639
+ __m512 c5_f32x16 = _mm512_set1_ps(0.2f);
640
+ __m512 c7_f32x16 = _mm512_set1_ps(-0.1428571429f);
641
+ __m512 c9_f32x16 = _mm512_set1_ps(0.1111111111f);
642
+ __m512 pi_half_f32x16 = _mm512_set1_ps(1.5707963268f);
643
+ __m512 one_f32x16 = _mm512_set1_ps(1.0f);
644
+
645
+ __mmask16 negative_mask = _mm512_cmp_ps_mask(values_f32x16, _mm512_setzero_ps(), _CMP_LT_OS);
646
+ values_f32x16 = _mm512_abs_ps(values_f32x16);
647
+ __mmask16 reciprocal_mask = _mm512_cmp_ps_mask(values_f32x16, one_f32x16, _CMP_GT_OS);
648
+ values_f32x16 = _mm512_mask_div_ps(values_f32x16, reciprocal_mask, one_f32x16, values_f32x16);
649
+
650
+ __m512 x2_f32x16 = _mm512_mul_ps(values_f32x16, values_f32x16);
651
+ __m512 x3_f32x16 = _mm512_mul_ps(values_f32x16, x2_f32x16);
652
+
653
+ __m512 poly_f32x16 = c9_f32x16;
654
+ poly_f32x16 = _mm512_fmadd_ps(poly_f32x16, x2_f32x16, c7_f32x16);
655
+ poly_f32x16 = _mm512_fmadd_ps(poly_f32x16, x2_f32x16, c5_f32x16);
656
+ poly_f32x16 = _mm512_fmadd_ps(poly_f32x16, x2_f32x16, c3_f32x16);
657
+
658
+ __m512 result_f32x16 = _mm512_fmadd_ps(x3_f32x16, poly_f32x16, values_f32x16);
659
+ result_f32x16 = _mm512_mask_sub_ps(result_f32x16, reciprocal_mask, pi_half_f32x16, result_f32x16);
660
+ result_f32x16 = _mm512_mask_sub_ps(result_f32x16, negative_mask, _mm512_setzero_ps(), result_f32x16);
661
+ return _mm512_cvtps_ph(result_f32x16, _MM_FROUND_TO_NEAREST_INT);
662
+ }
663
+
664
+ NK_PUBLIC void nk_each_sin_f16_skylake(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs) {
665
+ nk_size_t i = 0;
666
+ for (; i + 16 <= n; i += 16) {
667
+ __m256i angles_f16x16 = _mm256_loadu_si256((__m256i const *)(ins + i));
668
+ __m256i result_f16x16 = nk_sin_f16x16_skylake_(angles_f16x16);
669
+ _mm256_storeu_si256((__m256i *)(outs + i), result_f16x16);
670
+ }
671
+ if (i < n) {
672
+ __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n - i);
673
+ __m256i angles_f16x16 = _mm256_maskz_loadu_epi16(mask, ins + i);
674
+ __m256i result_f16x16 = nk_sin_f16x16_skylake_(angles_f16x16);
675
+ _mm256_mask_storeu_epi16(outs + i, mask, result_f16x16);
676
+ }
677
+ }
678
+
679
+ NK_PUBLIC void nk_each_cos_f16_skylake(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs) {
680
+ nk_size_t i = 0;
681
+ for (; i + 16 <= n; i += 16) {
682
+ __m256i angles_f16x16 = _mm256_loadu_si256((__m256i const *)(ins + i));
683
+ __m256i result_f16x16 = nk_cos_f16x16_skylake_(angles_f16x16);
684
+ _mm256_storeu_si256((__m256i *)(outs + i), result_f16x16);
685
+ }
686
+ if (i < n) {
687
+ __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n - i);
688
+ __m256i angles_f16x16 = _mm256_maskz_loadu_epi16(mask, ins + i);
689
+ __m256i result_f16x16 = nk_cos_f16x16_skylake_(angles_f16x16);
690
+ _mm256_mask_storeu_epi16(outs + i, mask, result_f16x16);
691
+ }
692
+ }
693
+
694
+ NK_PUBLIC void nk_each_atan_f16_skylake(nk_f16_t const *ins, nk_size_t n, nk_f16_t *outs) {
695
+ nk_size_t i = 0;
696
+ for (; i + 16 <= n; i += 16) {
697
+ __m256i values_f16x16 = _mm256_loadu_si256((__m256i const *)(ins + i));
698
+ __m256i result_f16x16 = nk_atan_f16x16_skylake_(values_f16x16);
699
+ _mm256_storeu_si256((__m256i *)(outs + i), result_f16x16);
700
+ }
701
+ if (i < n) {
702
+ __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n - i);
703
+ __m256i values_f16x16 = _mm256_maskz_loadu_epi16(mask, ins + i);
704
+ __m256i result_f16x16 = nk_atan_f16x16_skylake_(values_f16x16);
705
+ _mm256_mask_storeu_epi16(outs + i, mask, result_f16x16);
706
+ }
707
+ }
708
+
709
+ #if defined(__clang__)
710
+ #pragma clang attribute pop
711
+ #elif defined(__GNUC__)
712
+ #pragma GCC pop_options
713
+ #endif
714
+
715
+ #if defined(__cplusplus)
716
+ } // extern "C"
717
+ #endif
718
+
719
+ #endif // NK_TARGET_SKYLAKE
720
+ #endif // NK_TARGET_X86_
721
+ #endif // NK_TRIGONOMETRY_SKYLAKE_H