numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -0,0 +1,752 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for POWER9 VSX.
3
+ * @file include/numkong/dot/powervsx.h
4
+ * @author Ash Vardanian
5
+ * @date March 23, 2026
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_powervsx_instructions Power9 VSX Dot Product Instructions
10
+ *
11
+ * Key Power9 VSX instructions for dot products:
12
+ *
13
+ * Intrinsic Instruction POWER9
14
+ * vec_madd(a, b, c) XVMADDADP/XVMADDASP 5cy FMA: a×b+c
15
+ * vec_msub(a, b, c) XVMSUBADP/XVMSUBASP 5cy FMS: a×b−c
16
+ * vec_msum(a, b, c) VMSUMUBM/VMSUMMBM 5cy i8/u8 widening multiply-sum → i32/u32
17
+ * vec_msum(a, b, c) VMSUMSHM/VMSUMUHM 5cy i16/u16 widening multiply-sum → i32/u32
18
+ * vec_doublee(a) XVCVSPDP 3cy Widen even f32 lanes → f64x2
19
+ * vec_doubleo(a) XVCVSPDP (odd) 3cy Widen odd f32 lanes → f64x2
20
+ * vec_unpackh(a) VUPKHSB/VUPKHSH 2cy Sign-extend high half (i8→i16 or i16→i32)
21
+ * vec_unpackl(a) VUPKLSB/VUPKLSH 2cy Sign-extend low half (i8→i16 or i16→i32)
22
+ * vec_xor(a, b) VXOR/XXLXOR 1cy Bitwise XOR
23
+ * vec_xl(off, ptr) LXV 5cy Aligned 16-byte load
24
+ * vec_xl_len(ptr, len) LXVL 5cy Partial load (Power9), zero-fills tail
25
+ * vec_extract_fp32_from_shorth XVCVHPSP (high) 5cy f16x4 → f32x4 from high half
26
+ * vec_extract_fp32_from_shortl XVCVHPSP (low) 5cy f16x4 → f32x4 from low half
27
+ * vec_popcnt(a) VPOPCNTB/H/W/D 2cy Per-element popcount
28
+ * vec_sum4s(a, b) VSUM4UBS/VSUM4SBS 5cy Sum groups of 4 bytes → i32/u32
29
+ * vec_sums(a, b) VSUMSWS 5cy Signed i32x4 horizontal → i32 (lane 3)
30
+ *
31
+ * Power9 (POWER ISA 3.0) provides `vec_xl_len` for partial loads that zero-fill unused bytes,
32
+ * enabling branchless tail handling: zero × anything = zero, so partial vectors contribute
33
+ * no spurious terms to dot-product accumulators.
34
+ *
35
+ * @section dot_powervsx_stateful Stateful Streaming Logic
36
+ *
37
+ * For memory-optimal tiled algorithms, this file defines state structures and force-inlined
38
+ * `NK_INTERNAL` functions:
39
+ *
40
+ * - nk_dot_f32x2 state for f32 inputs with double-precision accumulation,
41
+ * - nk_dot_f64x2 state with Dot2 stable dot-products for f64 inputs,
42
+ * - nk_dot_bf16x8 state for bf16 inputs with f32 accumulation,
43
+ * - nk_dot_f16x8 state for f16 inputs with f32 accumulation,
44
+ * - nk_dot_i8x16 state for i8 inputs with i32 accumulation,
45
+ * - nk_dot_u8x16 state for u8 inputs with u32 accumulation,
46
+ * - nk_dot_u1x128 state for binary inputs with u64 popcount accumulation.
47
+ */
48
+ #ifndef NK_DOT_POWERVSX_H
49
+ #define NK_DOT_POWERVSX_H
50
+
51
+ #if NK_TARGET_POWERVSX
52
+
53
+ #if defined(__cplusplus)
54
+ extern "C" {
55
+ #endif
56
+
57
+ #if defined(__clang__)
58
+ #pragma clang attribute push(__attribute__((target("power9-vector"))), apply_to = function)
59
+ #elif defined(__GNUC__)
60
+ #pragma GCC push_options
61
+ #pragma GCC target("power9-vector")
62
+ #endif
63
+
64
+ /** @brief Horizontal sum of 4 f32 lanes → scalar f32. */
65
+ NK_INTERNAL nk_f32_t nk_hsum_f32x4_powervsx_(nk_vf32x4_t values_f32x4) {
66
+ // Rotate by 8 bytes (2 floats) and add → {v[0]+v[2], v[1]+v[3], ...}
67
+ nk_vf32x4_t rotated_f32x4 = vec_sld(values_f32x4, values_f32x4, 8);
68
+ nk_vf32x4_t partial_f32x4 = vec_add(values_f32x4, rotated_f32x4);
69
+ // Rotate by 4 bytes (1 float) and add → {v[0]+v[1]+v[2]+v[3], ...}
70
+ nk_vf32x4_t shifted_f32x4 = vec_sld(partial_f32x4, partial_f32x4, 4);
71
+ nk_vf32x4_t total_f32x4 = vec_add(partial_f32x4, shifted_f32x4);
72
+ return vec_extract(total_f32x4, 0);
73
+ }
74
+
75
+ /** @brief Horizontal sum of 2 f64 lanes → scalar f64 via xxpermdi (1 domain crossing). */
76
+ NK_INTERNAL nk_f64_t nk_hsum_f64x2_powervsx_(nk_vf64x2_t values_f64x2) {
77
+ nk_vf64x2_t swapped_f64x2 = vec_xxpermdi(values_f64x2, values_f64x2, 2);
78
+ nk_vf64x2_t sum_f64x2 = vec_add(values_f64x2, swapped_f64x2);
79
+ return vec_extract(sum_f64x2, 0);
80
+ }
81
+
82
+ /** @brief Horizontal sum of 4 signed i32 lanes → scalar i32. */
83
+ NK_INTERNAL nk_i32_t nk_hsum_i32x4_powervsx_(nk_vi32x4_t values_i32x4) {
84
+ // vec_sums reduces i32x4 → i32 in lane 3 of the result
85
+ nk_vi32x4_t zero_i32x4 = vec_splats((nk_i32_t)0);
86
+ nk_vi32x4_t sums_i32x4 = vec_sums(values_i32x4, zero_i32x4);
87
+ return vec_extract(sums_i32x4, 3);
88
+ }
89
+
90
+ /** @brief Horizontal sum of 4 unsigned u32 lanes → scalar u32. */
91
+ NK_INTERNAL nk_u32_t nk_hsum_u32x4_powervsx_(nk_vu32x4_t values_u32x4) {
92
+ // Rotate by 8 bytes (2 ints) and add → {v[0]+v[2], v[1]+v[3], ...}
93
+ nk_vu32x4_t rotated_u32x4 = vec_sld(values_u32x4, values_u32x4, 8);
94
+ nk_vu32x4_t partial_u32x4 = vec_add(values_u32x4, rotated_u32x4);
95
+ // Rotate by 4 bytes (1 int) and add → {v[0]+v[1]+v[2]+v[3], ...}
96
+ nk_vu32x4_t shifted_u32x4 = vec_sld(partial_u32x4, partial_u32x4, 4);
97
+ nk_vu32x4_t total_u32x4 = vec_add(partial_u32x4, shifted_u32x4);
98
+ return vec_extract(total_u32x4, 0);
99
+ }
100
+
101
+ /** @brief Horizontal sum of 2 unsigned u64 lanes → scalar u64 via xxpermdi. */
102
+ NK_INTERNAL nk_u64_t nk_hsum_u64x2_powervsx_(nk_vu64x2_t values_u64x2) {
103
+ nk_vu64x2_t swapped_u64x2 = vec_xxpermdi(values_u64x2, values_u64x2, 2);
104
+ nk_vu64x2_t sum_u64x2 = vec_add(values_u64x2, swapped_u64x2);
105
+ return vec_extract(sum_u64x2, 0);
106
+ }
107
+
108
+ /** @brief Compensated horizontal sum of 2 f64 lanes via TwoSum. */
109
+ NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x2_powervsx_(nk_vf64x2_t sum_f64x2, nk_vf64x2_t compensation_f64x2) {
110
+ // TwoSum merge of sum + compensation (2-wide)
111
+ nk_vf64x2_t tentative_sum_f64x2 = vec_add(sum_f64x2, compensation_f64x2);
112
+ nk_vf64x2_t virtual_addend_f64x2 = vec_sub(tentative_sum_f64x2, sum_f64x2);
113
+ nk_vf64x2_t rounding_error_f64x2 = vec_add(vec_sub(sum_f64x2, vec_sub(tentative_sum_f64x2, virtual_addend_f64x2)),
114
+ vec_sub(compensation_f64x2, virtual_addend_f64x2));
115
+ // Scalar TwoSum 2 → 1
116
+ nk_f64_t lower_sum = vec_extract(tentative_sum_f64x2, 0);
117
+ nk_f64_t upper_sum = vec_extract(tentative_sum_f64x2, 1);
118
+ nk_f64_t lower_error = vec_extract(rounding_error_f64x2, 0);
119
+ nk_f64_t upper_error = vec_extract(rounding_error_f64x2, 1);
120
+ nk_f64_t tentative_sum = lower_sum + upper_sum;
121
+ nk_f64_t virtual_addend = tentative_sum - lower_sum;
122
+ nk_f64_t rounding_error = (lower_sum - (tentative_sum - virtual_addend)) + (upper_sum - virtual_addend);
123
+ return tentative_sum + (lower_error + upper_error + rounding_error);
124
+ }
125
+
126
+ #pragma region F32 and F64 Floats
127
+
128
+ NK_PUBLIC void nk_dot_f32_powervsx(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
129
+ nk_f64_t *result) {
130
+ // Upcast f32 → f64 for accumulation via vec_doublee (even lanes) and vec_doubleo (odd lanes)
131
+ nk_vf64x2_t sum_even_f64x2 = vec_splats((nk_f64_t)0);
132
+ nk_vf64x2_t sum_odd_f64x2 = vec_splats((nk_f64_t)0);
133
+ nk_vf32x4_t a_f32x4, b_f32x4;
134
+ nk_size_t tail_bytes;
135
+
136
+ nk_dot_f32_powervsx_cycle:
137
+ if (count_scalars < 4) {
138
+ tail_bytes = count_scalars * sizeof(nk_f32_t);
139
+ a_f32x4 = vec_xl_len((nk_f32_t *)a_scalars, tail_bytes);
140
+ b_f32x4 = vec_xl_len((nk_f32_t *)b_scalars, tail_bytes);
141
+ count_scalars = 0;
142
+ }
143
+ else {
144
+ a_f32x4 = vec_xl(0, a_scalars);
145
+ b_f32x4 = vec_xl(0, b_scalars);
146
+ a_scalars += 4, b_scalars += 4, count_scalars -= 4;
147
+ }
148
+
149
+ // Widen even/odd f32 lanes → f64x2, then FMA
150
+ nk_vf64x2_t a_even_f64x2 = vec_doublee(a_f32x4);
151
+ nk_vf64x2_t b_even_f64x2 = vec_doublee(b_f32x4);
152
+ nk_vf64x2_t a_odd_f64x2 = vec_doubleo(a_f32x4);
153
+ nk_vf64x2_t b_odd_f64x2 = vec_doubleo(b_f32x4);
154
+ sum_even_f64x2 = vec_madd(a_even_f64x2, b_even_f64x2, sum_even_f64x2);
155
+ sum_odd_f64x2 = vec_madd(a_odd_f64x2, b_odd_f64x2, sum_odd_f64x2);
156
+
157
+ if (count_scalars) goto nk_dot_f32_powervsx_cycle;
158
+ // Combine even and odd accumulators → final scalar
159
+ nk_vf64x2_t total_f64x2 = vec_add(sum_even_f64x2, sum_odd_f64x2);
160
+ *result = nk_hsum_f64x2_powervsx_(total_f64x2);
161
+ }
162
+
163
+ NK_PUBLIC void nk_dot_f64_powervsx(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
164
+ nk_f64_t *result) {
165
+ // Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated dot product
166
+ nk_vf64x2_t sum_f64x2 = vec_splats((nk_f64_t)0);
167
+ nk_vf64x2_t compensation_f64x2 = vec_splats((nk_f64_t)0);
168
+ nk_vf64x2_t a_f64x2, b_f64x2;
169
+ nk_size_t tail_bytes;
170
+
171
+ nk_dot_f64_powervsx_cycle:
172
+ if (count_scalars < 2) {
173
+ tail_bytes = count_scalars * sizeof(nk_f64_t);
174
+ a_f64x2 = vec_xl_len((nk_f64_t *)a_scalars, tail_bytes);
175
+ b_f64x2 = vec_xl_len((nk_f64_t *)b_scalars, tail_bytes);
176
+ count_scalars = 0;
177
+ }
178
+ else {
179
+ a_f64x2 = vec_xl(0, a_scalars);
180
+ b_f64x2 = vec_xl(0, b_scalars);
181
+ a_scalars += 2, b_scalars += 2, count_scalars -= 2;
182
+ }
183
+
184
+ // TwoProd: product = a * b, error = msub(a, b, product) captures rounding error
185
+ nk_vf64x2_t product_f64x2 = vec_mul(a_f64x2, b_f64x2);
186
+ nk_vf64x2_t product_error_f64x2 = vec_msub(a_f64x2, b_f64x2, product_f64x2);
187
+ // TwoSum: (t, q) = TwoSum(sum, product) where t = sum + product rounded, q = error
188
+ nk_vf64x2_t tentative_sum_f64x2 = vec_add(sum_f64x2, product_f64x2);
189
+ nk_vf64x2_t virtual_addend_f64x2 = vec_sub(tentative_sum_f64x2, sum_f64x2);
190
+ nk_vf64x2_t sum_error_f64x2 = vec_add(vec_sub(sum_f64x2, vec_sub(tentative_sum_f64x2, virtual_addend_f64x2)),
191
+ vec_sub(product_f64x2, virtual_addend_f64x2));
192
+ // Update: sum = t, compensation += q + r
193
+ sum_f64x2 = tentative_sum_f64x2;
194
+ compensation_f64x2 = vec_add(compensation_f64x2, vec_add(sum_error_f64x2, product_error_f64x2));
195
+
196
+ if (count_scalars) goto nk_dot_f64_powervsx_cycle;
197
+ // Compensated horizontal reduction preserving Dot2 error tracking
198
+ *result = nk_dot_stable_sum_f64x2_powervsx_(sum_f64x2, compensation_f64x2);
199
+ }
200
+
201
+ #pragma endregion F32 and F64 Floats
202
+ #pragma region F16 and BF16 Floats
203
+
204
+ NK_PUBLIC void nk_dot_bf16_powervsx(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
205
+ nk_f32_t *result) {
206
+ // bf16 → f32 via mergeh/mergel with zero: shift 16 bits into f32 upper half
207
+ nk_vu16x8_t zero_u16x8 = vec_splats((nk_u16_t)0);
208
+ nk_vf32x4_t sum_f32x4 = vec_splats((nk_f32_t)0);
209
+ nk_vu16x8_t a_u16x8, b_u16x8;
210
+ nk_size_t tail_bytes;
211
+
212
+ nk_dot_bf16_powervsx_cycle:
213
+ if (count_scalars < 8) {
214
+ tail_bytes = count_scalars * sizeof(nk_bf16_t);
215
+ a_u16x8 = vec_xl_len((nk_u16_t *)a_scalars, tail_bytes);
216
+ b_u16x8 = vec_xl_len((nk_u16_t *)b_scalars, tail_bytes);
217
+ count_scalars = 0;
218
+ }
219
+ else {
220
+ a_u16x8 = vec_xl(0, (nk_u16_t const *)a_scalars);
221
+ b_u16x8 = vec_xl(0, (nk_u16_t const *)b_scalars);
222
+ a_scalars += 8, b_scalars += 8, count_scalars -= 8;
223
+ }
224
+
225
+ // Convert bf16 → f32: merge with zero puts bf16 bits in upper 16 of each f32
226
+ nk_vf32x4_t a_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, a_u16x8);
227
+ nk_vf32x4_t a_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, a_u16x8);
228
+ nk_vf32x4_t b_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, b_u16x8);
229
+ nk_vf32x4_t b_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, b_u16x8);
230
+ sum_f32x4 = vec_madd(a_high_f32x4, b_high_f32x4, sum_f32x4);
231
+ sum_f32x4 = vec_madd(a_low_f32x4, b_low_f32x4, sum_f32x4);
232
+
233
+ if (count_scalars) goto nk_dot_bf16_powervsx_cycle;
234
+ *result = nk_hsum_f32x4_powervsx_(sum_f32x4);
235
+ }
236
+
237
+ NK_PUBLIC void nk_dot_f16_powervsx(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
238
+ nk_f32_t *result) {
239
+ // f16 → f32 via vec_extract_fp32_from_shorth/shortl (Power9 XVCVHPSP)
240
+ nk_vf32x4_t sum_f32x4 = vec_splats((nk_f32_t)0);
241
+ nk_vu16x8_t a_u16x8, b_u16x8;
242
+ nk_size_t tail_bytes;
243
+
244
+ nk_dot_f16_powervsx_cycle:
245
+ if (count_scalars < 8) {
246
+ tail_bytes = count_scalars * sizeof(nk_f16_t);
247
+ a_u16x8 = vec_xl_len((nk_u16_t *)a_scalars, tail_bytes);
248
+ b_u16x8 = vec_xl_len((nk_u16_t *)b_scalars, tail_bytes);
249
+ count_scalars = 0;
250
+ }
251
+ else {
252
+ a_u16x8 = vec_xl(0, (nk_u16_t const *)a_scalars);
253
+ b_u16x8 = vec_xl(0, (nk_u16_t const *)b_scalars);
254
+ a_scalars += 8, b_scalars += 8, count_scalars -= 8;
255
+ }
256
+
257
+ // Convert f16 → f32 via hardware XVCVHPSP
258
+ nk_vf32x4_t a_high_f32x4 = vec_extract_fp32_from_shorth(a_u16x8);
259
+ nk_vf32x4_t a_low_f32x4 = vec_extract_fp32_from_shortl(a_u16x8);
260
+ nk_vf32x4_t b_high_f32x4 = vec_extract_fp32_from_shorth(b_u16x8);
261
+ nk_vf32x4_t b_low_f32x4 = vec_extract_fp32_from_shortl(b_u16x8);
262
+ sum_f32x4 = vec_madd(a_high_f32x4, b_high_f32x4, sum_f32x4);
263
+ sum_f32x4 = vec_madd(a_low_f32x4, b_low_f32x4, sum_f32x4);
264
+
265
+ if (count_scalars) goto nk_dot_f16_powervsx_cycle;
266
+ *result = nk_hsum_f32x4_powervsx_(sum_f32x4);
267
+ }
268
+
269
+ #pragma endregion F16 and BF16 Floats
270
+ #pragma region I8 and U8 Integers
271
+
272
+ NK_PUBLIC void nk_dot_i8_powervsx(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
273
+ nk_i32_t *result) {
274
+ // Algebraic transform for i8×i8 using VMSUMMBM (i8×u8 → i32):
275
+ // b' = b ⊕ 0x80 (reinterpret signed as unsigned)
276
+ // a·b = a·b' − 128·Σa
277
+ // Σ(a+128) accumulated via VSUM4UBS; correction applied after loop.
278
+ // Tail handling is free: vec_xl_len zero-fills unused lanes.
279
+ // - Product: 0 × (0⊕0x80) = 0 → no spurious contribution
280
+ // - Correction: (0⊕0x80) = 128 in sum_a_biased, compensated by count_padded
281
+ nk_vu8x16_t const bias_u8x16 = vec_splats((nk_u8_t)0x80);
282
+ nk_vi32x4_t accumulator_i32x4 = vec_splats((nk_i32_t)0);
283
+ nk_vu32x4_t sum_a_biased_u32x4 = vec_splats((nk_u32_t)0);
284
+ nk_size_t count_padded = ((count_scalars + 15) / 16) * 16;
285
+ nk_vi8x16_t a_i8x16;
286
+ nk_vu8x16_t b_biased_u8x16;
287
+ nk_size_t tail_bytes;
288
+
289
+ nk_dot_i8_powervsx_cycle:
290
+ if (count_scalars < 16) {
291
+ tail_bytes = count_scalars * sizeof(nk_i8_t);
292
+ a_i8x16 = vec_xl_len((nk_i8_t *)a_scalars, tail_bytes);
293
+ b_biased_u8x16 = vec_xor(vec_xl_len((nk_u8_t *)b_scalars, tail_bytes), bias_u8x16);
294
+ count_scalars = 0;
295
+ }
296
+ else {
297
+ a_i8x16 = vec_xl(0, a_scalars);
298
+ b_biased_u8x16 = vec_xor(vec_xl(0, (nk_u8_t *)b_scalars), bias_u8x16);
299
+ a_scalars += 16, b_scalars += 16, count_scalars -= 16;
300
+ }
301
+
302
+ // VMSUMMBM: i8 × u8 → i32 (16 products per instruction)
303
+ accumulator_i32x4 = vec_msum(a_i8x16, b_biased_u8x16, accumulator_i32x4);
304
+ // VSUM4UBS: accumulate Σ(a+128) as unsigned (independent chain, good ILP)
305
+ sum_a_biased_u32x4 = vec_sum4s(vec_xor((nk_vu8x16_t)a_i8x16, bias_u8x16), sum_a_biased_u32x4);
306
+
307
+ if (count_scalars) goto nk_dot_i8_powervsx_cycle;
308
+
309
+ // Correction: a·b = biased_dot − 128·Σa = biased_dot − 128·(Σ(a+128) − 128·count_padded)
310
+ nk_i32_t biased_dot = nk_hsum_i32x4_powervsx_(accumulator_i32x4);
311
+ nk_i64_t correction = 128LL * (nk_i64_t)nk_hsum_u32x4_powervsx_(sum_a_biased_u32x4) -
312
+ 16384LL * (nk_i64_t)count_padded;
313
+ *result = (nk_i32_t)((nk_i64_t)biased_dot - correction);
314
+ }
315
+
316
+ NK_PUBLIC void nk_dot_u8_powervsx(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
317
+ nk_u32_t *result) {
318
+ // vec_msum: multiply u8×u8 pairs and accumulate 16 products → 4 u32 lanes per call
319
+ nk_vu32x4_t accumulator_u32x4 = vec_splats((nk_u32_t)0);
320
+ nk_vu8x16_t a_u8x16, b_u8x16;
321
+ nk_size_t tail_bytes;
322
+
323
+ nk_dot_u8_powervsx_cycle:
324
+ if (count_scalars < 16) {
325
+ tail_bytes = count_scalars * sizeof(nk_u8_t);
326
+ a_u8x16 = vec_xl_len((nk_u8_t *)a_scalars, tail_bytes);
327
+ b_u8x16 = vec_xl_len((nk_u8_t *)b_scalars, tail_bytes);
328
+ count_scalars = 0;
329
+ }
330
+ else {
331
+ a_u8x16 = vec_xl(0, a_scalars);
332
+ b_u8x16 = vec_xl(0, b_scalars);
333
+ a_scalars += 16, b_scalars += 16, count_scalars -= 16;
334
+ }
335
+
336
+ // Unsigned × unsigned multiply-sum: 16 u8 products accumulated into 4 u32 lanes
337
+ accumulator_u32x4 = vec_msum(a_u8x16, b_u8x16, accumulator_u32x4);
338
+
339
+ if (count_scalars) goto nk_dot_u8_powervsx_cycle;
340
+ *result = nk_hsum_u32x4_powervsx_(accumulator_u32x4);
341
+ }
342
+
343
+ #pragma endregion I8 and U8 Integers
344
+ #pragma region Binary
345
+
346
+ NK_PUBLIC void nk_dot_u1_powervsx(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
347
+ nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
348
+ nk_vu64x2_t accumulator_u64x2 = vec_splats((nk_u64_t)0);
349
+ nk_vu8x16_t a_u8x16, b_u8x16;
350
+
351
+ nk_dot_u1_powervsx_cycle:
352
+ if (n_bytes < 16) {
353
+ a_u8x16 = vec_xl_len((nk_u8_t *)a, n_bytes);
354
+ b_u8x16 = vec_xl_len((nk_u8_t *)b, n_bytes);
355
+ n_bytes = 0;
356
+ }
357
+ else {
358
+ a_u8x16 = vec_xl(0, (nk_u8_t const *)a);
359
+ b_u8x16 = vec_xl(0, (nk_u8_t const *)b);
360
+ a += 16, b += 16, n_bytes -= 16;
361
+ }
362
+
363
+ // AND → doubleword popcount (vpopcntd) → accumulate u64 lanes
364
+ nk_vu8x16_t and_u8x16 = vec_and(a_u8x16, b_u8x16);
365
+ nk_vu64x2_t popcnt_u64x2 = vec_popcnt((nk_vu64x2_t)and_u8x16);
366
+ accumulator_u64x2 = vec_add(accumulator_u64x2, popcnt_u64x2);
367
+
368
+ if (n_bytes) goto nk_dot_u1_powervsx_cycle;
369
+ *result = (nk_u32_t)nk_hsum_u64x2_powervsx_(accumulator_u64x2);
370
+ }
371
+
372
+ #pragma endregion Binary
373
+
374
+ /**
375
+ * @brief Running state for 128-bit dot accumulation over f32 scalars on Power VSX.
376
+ *
377
+ * Processes 2 f32 values at a time, upcasting to f64 for accumulation to avoid
378
+ * catastrophic cancellation in long reductions.
379
+ */
380
+ typedef struct nk_dot_f32x2_state_powervsx_t {
381
+ nk_vf64x2_t sum_f64x2;
382
+ } nk_dot_f32x2_state_powervsx_t;
383
+
384
+ NK_INTERNAL void nk_dot_f32x2_init_powervsx(nk_dot_f32x2_state_powervsx_t *state) {
385
+ state->sum_f64x2 = vec_splats((nk_f64_t)0);
386
+ }
387
+
388
+ NK_INTERNAL void nk_dot_f32x2_update_powervsx(nk_dot_f32x2_state_powervsx_t *state, nk_b64_vec_t a, nk_b64_vec_t b,
389
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
390
+ nk_unused_(depth_offset);
391
+ nk_unused_(active_dimensions);
392
+ // Load 8 bytes (2 f32s) into a vector register, zero-filling the upper 8 bytes
393
+ nk_vf32x4_t a_f32x4 = vec_xl_len((nk_f32_t *)a.f32s, 8);
394
+ nk_vf32x4_t b_f32x4 = vec_xl_len((nk_f32_t *)b.f32s, 8);
395
+ // Widen even lanes (the two f32 values) → f64x2
396
+ nk_vf64x2_t a_f64x2 = vec_doublee(a_f32x4);
397
+ nk_vf64x2_t b_f64x2 = vec_doublee(b_f32x4);
398
+ // Permute to get {lane0, lane2} → {a[0], a[1]} as f64x2
399
+ a_f64x2 = vec_xxpermdi(a_f64x2, vec_doubleo(a_f32x4), 0);
400
+ b_f64x2 = vec_xxpermdi(b_f64x2, vec_doubleo(b_f32x4), 0);
401
+ state->sum_f64x2 = vec_madd(a_f64x2, b_f64x2, state->sum_f64x2);
402
+ }
403
+
404
+ NK_INTERNAL void nk_dot_f32x2_finalize_powervsx( //
405
+ nk_dot_f32x2_state_powervsx_t const *state_a, nk_dot_f32x2_state_powervsx_t const *state_b, //
406
+ nk_dot_f32x2_state_powervsx_t const *state_c, nk_dot_f32x2_state_powervsx_t const *state_d, //
407
+ nk_size_t total_dimensions, nk_b256_vec_t *result) {
408
+ nk_unused_(total_dimensions);
409
+ nk_vf64x2_t sum_a_f64x2 = vec_add(state_a->sum_f64x2, vec_xxpermdi(state_a->sum_f64x2, state_a->sum_f64x2, 2));
410
+ nk_vf64x2_t sum_b_f64x2 = vec_add(state_b->sum_f64x2, vec_xxpermdi(state_b->sum_f64x2, state_b->sum_f64x2, 2));
411
+ nk_vf64x2_t sum_c_f64x2 = vec_add(state_c->sum_f64x2, vec_xxpermdi(state_c->sum_f64x2, state_c->sum_f64x2, 2));
412
+ nk_vf64x2_t sum_d_f64x2 = vec_add(state_d->sum_f64x2, vec_xxpermdi(state_d->sum_f64x2, state_d->sum_f64x2, 2));
413
+ result->vf64x2s[0] = vec_xxpermdi(sum_a_f64x2, sum_b_f64x2, 0);
414
+ result->vf64x2s[1] = vec_xxpermdi(sum_c_f64x2, sum_d_f64x2, 0);
415
+ }
416
+
417
+ /**
418
+ * @brief Running state for 128-bit dot accumulation over f64 scalars on Power VSX.
419
+ *
420
+ * Uses the Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated dot product.
421
+ */
422
+ typedef struct nk_dot_f64x2_state_powervsx_t {
423
+ nk_vf64x2_t sum_f64x2;
424
+ nk_vf64x2_t compensation_f64x2;
425
+ } nk_dot_f64x2_state_powervsx_t;
426
+
427
+ NK_INTERNAL void nk_dot_f64x2_init_powervsx(nk_dot_f64x2_state_powervsx_t *state) {
428
+ state->sum_f64x2 = vec_splats((nk_f64_t)0);
429
+ state->compensation_f64x2 = vec_splats((nk_f64_t)0);
430
+ }
431
+
432
+ NK_INTERNAL void nk_dot_f64x2_update_powervsx(nk_dot_f64x2_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
433
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
434
+ nk_unused_(depth_offset);
435
+ nk_unused_(active_dimensions);
436
+ nk_vf64x2_t sum_f64x2 = state->sum_f64x2;
437
+ nk_vf64x2_t compensation_f64x2 = state->compensation_f64x2;
438
+ nk_vf64x2_t a_f64x2 = a.vf64x2;
439
+ nk_vf64x2_t b_f64x2 = b.vf64x2;
440
+
441
+ // TwoProd: product = a × b, error = msub(a, b, product) captures rounding error
442
+ nk_vf64x2_t product_f64x2 = vec_mul(a_f64x2, b_f64x2);
443
+ nk_vf64x2_t product_error_f64x2 = vec_msub(a_f64x2, b_f64x2, product_f64x2);
444
+
445
+ // TwoSum: (t, q) = TwoSum(sum, product) where t = sum + product rounded, q = error
446
+ nk_vf64x2_t tentative_sum_f64x2 = vec_add(sum_f64x2, product_f64x2);
447
+ nk_vf64x2_t virtual_addend_f64x2 = vec_sub(tentative_sum_f64x2, sum_f64x2);
448
+ nk_vf64x2_t sum_error_f64x2 = vec_add(vec_sub(sum_f64x2, vec_sub(tentative_sum_f64x2, virtual_addend_f64x2)),
449
+ vec_sub(product_f64x2, virtual_addend_f64x2));
450
+
451
+ // Update: sum = t, compensation += q + r
452
+ state->sum_f64x2 = tentative_sum_f64x2;
453
+ state->compensation_f64x2 = vec_add(compensation_f64x2, vec_add(sum_error_f64x2, product_error_f64x2));
454
+ }
455
+
456
+ NK_INTERNAL void nk_dot_f64x2_finalize_powervsx( //
457
+ nk_dot_f64x2_state_powervsx_t const *state_a, nk_dot_f64x2_state_powervsx_t const *state_b, //
458
+ nk_dot_f64x2_state_powervsx_t const *state_c, nk_dot_f64x2_state_powervsx_t const *state_d, //
459
+ nk_size_t total_dimensions, nk_b256_vec_t *result) {
460
+ nk_unused_(total_dimensions);
461
+ // Compensated horizontal reduction preserving Dot2 error tracking per state
462
+ result->f64s[0] = nk_dot_stable_sum_f64x2_powervsx_(state_a->sum_f64x2, state_a->compensation_f64x2);
463
+ result->f64s[1] = nk_dot_stable_sum_f64x2_powervsx_(state_b->sum_f64x2, state_b->compensation_f64x2);
464
+ result->f64s[2] = nk_dot_stable_sum_f64x2_powervsx_(state_c->sum_f64x2, state_c->compensation_f64x2);
465
+ result->f64s[3] = nk_dot_stable_sum_f64x2_powervsx_(state_d->sum_f64x2, state_d->compensation_f64x2);
466
+ }
467
+
468
+ /**
469
+ * @brief Running state for 128-bit dot accumulation over bf16 scalars on Power VSX.
470
+ *
471
+ * Processes 8 bf16 values at a time (128 bits), converting to f32 via vec_mergeh/mergel
472
+ * with zero for accumulation.
473
+ */
474
+ typedef struct nk_dot_bf16x8_state_powervsx_t {
475
+ nk_vf32x4_t sum_f32x4;
476
+ } nk_dot_bf16x8_state_powervsx_t;
477
+
478
+ NK_INTERNAL void nk_dot_bf16x8_init_powervsx(nk_dot_bf16x8_state_powervsx_t *state) {
479
+ state->sum_f32x4 = vec_splats((nk_f32_t)0);
480
+ }
481
+
482
+ NK_INTERNAL void nk_dot_bf16x8_update_powervsx(nk_dot_bf16x8_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
483
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
484
+ nk_unused_(depth_offset);
485
+ nk_unused_(active_dimensions);
486
+ // Convert bf16 → f32 inline: merge with zero puts bf16 bits in upper 16 of each f32
487
+ nk_vu16x8_t zero_u16x8 = vec_splats((nk_u16_t)0);
488
+ nk_vu16x8_t a_u16x8 = a.vu16x8;
489
+ nk_vu16x8_t b_u16x8 = b.vu16x8;
490
+ nk_vf32x4_t a_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, a_u16x8);
491
+ nk_vf32x4_t a_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, a_u16x8);
492
+ nk_vf32x4_t b_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, b_u16x8);
493
+ nk_vf32x4_t b_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, b_u16x8);
494
+ state->sum_f32x4 = vec_madd(a_high_f32x4, b_high_f32x4, state->sum_f32x4);
495
+ state->sum_f32x4 = vec_madd(a_low_f32x4, b_low_f32x4, state->sum_f32x4);
496
+ }
497
+
498
+ NK_INTERNAL void nk_dot_bf16x8_finalize_powervsx( //
499
+ nk_dot_bf16x8_state_powervsx_t const *state_a, nk_dot_bf16x8_state_powervsx_t const *state_b, //
500
+ nk_dot_bf16x8_state_powervsx_t const *state_c, nk_dot_bf16x8_state_powervsx_t const *state_d, //
501
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
502
+ nk_unused_(total_dimensions);
503
+ nk_vf32x4_t a_f32x4 = state_a->sum_f32x4, b_f32x4 = state_b->sum_f32x4, c_f32x4 = state_c->sum_f32x4,
504
+ d_f32x4 = state_d->sum_f32x4;
505
+ nk_vf32x4_t transpose_ab_low_f32x4 = vec_mergeh(a_f32x4, b_f32x4);
506
+ nk_vf32x4_t transpose_cd_low_f32x4 = vec_mergeh(c_f32x4, d_f32x4);
507
+ nk_vf32x4_t transpose_ab_high_f32x4 = vec_mergel(a_f32x4, b_f32x4);
508
+ nk_vf32x4_t transpose_cd_high_f32x4 = vec_mergel(c_f32x4, d_f32x4);
509
+ nk_vf32x4_t sum_lane0_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_f32x4,
510
+ (nk_vu64x2_t)transpose_cd_low_f32x4, 0);
511
+ nk_vf32x4_t sum_lane1_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_f32x4,
512
+ (nk_vu64x2_t)transpose_cd_low_f32x4, 3);
513
+ nk_vf32x4_t sum_lane2_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_f32x4,
514
+ (nk_vu64x2_t)transpose_cd_high_f32x4, 0);
515
+ nk_vf32x4_t sum_lane3_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_f32x4,
516
+ (nk_vu64x2_t)transpose_cd_high_f32x4, 3);
517
+ result->vf32x4 = vec_add(vec_add(sum_lane0_f32x4, sum_lane1_f32x4), vec_add(sum_lane2_f32x4, sum_lane3_f32x4));
518
+ }
519
+
520
+ /**
521
+ * @brief Running state for 128-bit dot accumulation over f16 scalars on Power VSX.
522
+ *
523
+ * Processes 8 f16 values at a time (128 bits), converting to f32 via
524
+ * vec_extract_fp32_from_shorth/shortl for accumulation.
525
+ */
526
+ typedef struct nk_dot_f16x8_state_powervsx_t {
527
+ nk_vf32x4_t sum_f32x4;
528
+ } nk_dot_f16x8_state_powervsx_t;
529
+
530
+ NK_INTERNAL void nk_dot_f16x8_init_powervsx(nk_dot_f16x8_state_powervsx_t *state) {
531
+ state->sum_f32x4 = vec_splats((nk_f32_t)0);
532
+ }
533
+
534
+ NK_INTERNAL void nk_dot_f16x8_update_powervsx(nk_dot_f16x8_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
535
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
536
+ nk_unused_(depth_offset);
537
+ nk_unused_(active_dimensions);
538
+ // Convert f16 → f32 via hardware XVCVHPSP
539
+ nk_vu16x8_t a_u16x8 = a.vu16x8;
540
+ nk_vu16x8_t b_u16x8 = b.vu16x8;
541
+ nk_vf32x4_t a_high_f32x4 = vec_extract_fp32_from_shorth(a_u16x8);
542
+ nk_vf32x4_t a_low_f32x4 = vec_extract_fp32_from_shortl(a_u16x8);
543
+ nk_vf32x4_t b_high_f32x4 = vec_extract_fp32_from_shorth(b_u16x8);
544
+ nk_vf32x4_t b_low_f32x4 = vec_extract_fp32_from_shortl(b_u16x8);
545
+ state->sum_f32x4 = vec_madd(a_high_f32x4, b_high_f32x4, state->sum_f32x4);
546
+ state->sum_f32x4 = vec_madd(a_low_f32x4, b_low_f32x4, state->sum_f32x4);
547
+ }
548
+
549
+ NK_INTERNAL void nk_dot_f16x8_finalize_powervsx( //
550
+ nk_dot_f16x8_state_powervsx_t const *state_a, nk_dot_f16x8_state_powervsx_t const *state_b, //
551
+ nk_dot_f16x8_state_powervsx_t const *state_c, nk_dot_f16x8_state_powervsx_t const *state_d, //
552
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
553
+ nk_unused_(total_dimensions);
554
+ nk_vf32x4_t a_f32x4 = state_a->sum_f32x4, b_f32x4 = state_b->sum_f32x4, c_f32x4 = state_c->sum_f32x4,
555
+ d_f32x4 = state_d->sum_f32x4;
556
+ nk_vf32x4_t transpose_ab_low_f32x4 = vec_mergeh(a_f32x4, b_f32x4);
557
+ nk_vf32x4_t transpose_cd_low_f32x4 = vec_mergeh(c_f32x4, d_f32x4);
558
+ nk_vf32x4_t transpose_ab_high_f32x4 = vec_mergel(a_f32x4, b_f32x4);
559
+ nk_vf32x4_t transpose_cd_high_f32x4 = vec_mergel(c_f32x4, d_f32x4);
560
+ nk_vf32x4_t sum_lane0_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_f32x4,
561
+ (nk_vu64x2_t)transpose_cd_low_f32x4, 0);
562
+ nk_vf32x4_t sum_lane1_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_f32x4,
563
+ (nk_vu64x2_t)transpose_cd_low_f32x4, 3);
564
+ nk_vf32x4_t sum_lane2_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_f32x4,
565
+ (nk_vu64x2_t)transpose_cd_high_f32x4, 0);
566
+ nk_vf32x4_t sum_lane3_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_f32x4,
567
+ (nk_vu64x2_t)transpose_cd_high_f32x4, 3);
568
+ result->vf32x4 = vec_add(vec_add(sum_lane0_f32x4, sum_lane1_f32x4), vec_add(sum_lane2_f32x4, sum_lane3_f32x4));
569
+ }
570
+
571
+ /**
572
+ * @brief Running state for 128-bit dot accumulation over i8 scalars on Power VSX.
573
+ *
574
+ * Algebraic transform: a·b = a·(b⊕0x80) − 128·Σa. Uses VMSUMMBM (i8×u8 → i32) for the biased
575
+ * product. Correction is applied at finalize using precomputed column sums from the compensated
576
+ * macro infrastructure.
577
+ */
578
+ typedef struct nk_dot_i8x16_state_powervsx_t {
579
+ nk_vi32x4_t biased_sum_i32x4;
580
+ } nk_dot_i8x16_state_powervsx_t;
581
+
582
+ NK_INTERNAL void nk_dot_i8x16_init_powervsx(nk_dot_i8x16_state_powervsx_t *state) {
583
+ state->biased_sum_i32x4 = vec_splats((nk_i32_t)0);
584
+ }
585
+
586
+ NK_INTERNAL void nk_dot_i8x16_update_powervsx(nk_dot_i8x16_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
587
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
588
+ nk_unused_(depth_offset);
589
+ nk_unused_(active_dimensions);
590
+ // VMSUMMBM(b, a⊕0x80) = Σ(b_i · (a_i+128)) = a·b + 128·Σb
591
+ // Swapping operands: b in signed slot, biased a in unsigned slot.
592
+ // Correction −128·Σb uses precomputed B column sums from the compensated macro.
593
+ nk_vu8x16_t const bias_u8x16 = vec_splats((nk_u8_t)0x80);
594
+ nk_vu8x16_t a_biased_u8x16 = vec_xor(a.vu8x16, bias_u8x16);
595
+ state->biased_sum_i32x4 = vec_msum(b.vi8x16, a_biased_u8x16, state->biased_sum_i32x4);
596
+ }
597
+
598
+ NK_INTERNAL void nk_dot_i8x16_finalize_powervsx( //
599
+ nk_dot_i8x16_state_powervsx_t const *state_a, nk_dot_i8x16_state_powervsx_t const *state_b, //
600
+ nk_dot_i8x16_state_powervsx_t const *state_c, nk_dot_i8x16_state_powervsx_t const *state_d, //
601
+ nk_size_t total_dimensions, //
602
+ nk_i32_t a_sum, nk_b128_vec_t b_sums, nk_b128_vec_t *result) {
603
+ nk_unused_(total_dimensions);
604
+ nk_unused_(a_sum);
605
+
606
+ // Transpose-reduce biased products across 4 accumulators → one i32x4
607
+ nk_vi32x4_t a_i32x4 = state_a->biased_sum_i32x4, b_i32x4 = state_b->biased_sum_i32x4,
608
+ c_i32x4 = state_c->biased_sum_i32x4, d_i32x4 = state_d->biased_sum_i32x4;
609
+ nk_vi32x4_t transpose_ab_low_i32x4 = vec_mergeh(a_i32x4, b_i32x4);
610
+ nk_vi32x4_t transpose_cd_low_i32x4 = vec_mergeh(c_i32x4, d_i32x4);
611
+ nk_vi32x4_t transpose_ab_high_i32x4 = vec_mergel(a_i32x4, b_i32x4);
612
+ nk_vi32x4_t transpose_cd_high_i32x4 = vec_mergel(c_i32x4, d_i32x4);
613
+ nk_vi32x4_t sum_lane0_i32x4 = (nk_vi32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_i32x4,
614
+ (nk_vu64x2_t)transpose_cd_low_i32x4, 0);
615
+ nk_vi32x4_t sum_lane1_i32x4 = (nk_vi32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_i32x4,
616
+ (nk_vu64x2_t)transpose_cd_low_i32x4, 3);
617
+ nk_vi32x4_t sum_lane2_i32x4 = (nk_vi32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_i32x4,
618
+ (nk_vu64x2_t)transpose_cd_high_i32x4, 0);
619
+ nk_vi32x4_t sum_lane3_i32x4 = (nk_vi32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_i32x4,
620
+ (nk_vu64x2_t)transpose_cd_high_i32x4, 3);
621
+ nk_vi32x4_t biased_i32x4 = vec_add(vec_add(sum_lane0_i32x4, sum_lane1_i32x4),
622
+ vec_add(sum_lane2_i32x4, sum_lane3_i32x4));
623
+
624
+ // Correction: VMSUMMBM(b, a⊕0x80) = Σ(b_i·(a_i+128)) = a·b + 128·Σb
625
+ // So a·b = biased − 128·Σb. B column sums are precomputed during packing.
626
+ nk_vu32x4_t shift_u32x4 = vec_splats((nk_u32_t)7);
627
+ nk_vi32x4_t correction_i32x4 = (nk_vi32x4_t)vec_sl((nk_vu32x4_t)b_sums.vi32x4, shift_u32x4);
628
+ result->vi32x4 = vec_sub(biased_i32x4, correction_i32x4);
629
+ }
630
+
631
+ /** @brief Running state for i8 column sum precomputation on Power VSX. */
632
+ typedef struct nk_sum_i8x16_state_powervsx_t {
633
+ nk_vu32x4_t biased_sum_u32x4;
634
+ } nk_sum_i8x16_state_powervsx_t;
635
+
636
+ NK_INTERNAL void nk_sum_i8x16_init_powervsx(nk_sum_i8x16_state_powervsx_t *state) {
637
+ state->biased_sum_u32x4 = vec_splats((nk_u32_t)0);
638
+ }
639
+
640
+ NK_INTERNAL void nk_sum_i8x16_update_powervsx(nk_sum_i8x16_state_powervsx_t *state, nk_b128_vec_t values_vec) {
641
+ nk_vu8x16_t const bias_u8x16 = vec_splats((nk_u8_t)0x80);
642
+ nk_vu8x16_t biased_u8x16 = vec_xor(values_vec.vu8x16, bias_u8x16);
643
+ state->biased_sum_u32x4 = vec_sum4s(biased_u8x16, state->biased_sum_u32x4);
644
+ }
645
+
646
+ NK_INTERNAL nk_i32_t nk_sum_i8x16_finalize_powervsx(nk_sum_i8x16_state_powervsx_t const *state, nk_size_t count) {
647
+ nk_u32_t biased_sum = nk_hsum_u32x4_powervsx_(state->biased_sum_u32x4);
648
+ return (nk_i32_t)((nk_i64_t)biased_sum - 128 * (nk_i64_t)count);
649
+ }
650
+
651
+ /**
652
+ * @brief Running state for 128-bit dot accumulation over u8 scalars on Power VSX.
653
+ *
654
+ * Processes 16 u8 values at a time via vec_msum, accumulating into 4 u32 lanes.
655
+ */
656
+ typedef struct nk_dot_u8x16_state_powervsx_t {
657
+ nk_vu32x4_t sum_u32x4;
658
+ } nk_dot_u8x16_state_powervsx_t;
659
+
660
+ NK_INTERNAL void nk_dot_u8x16_init_powervsx(nk_dot_u8x16_state_powervsx_t *state) {
661
+ state->sum_u32x4 = vec_splats((nk_u32_t)0);
662
+ }
663
+
664
+ NK_INTERNAL void nk_dot_u8x16_update_powervsx(nk_dot_u8x16_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
665
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
666
+ nk_unused_(depth_offset);
667
+ nk_unused_(active_dimensions);
668
+ // Unsigned × unsigned multiply-sum: 16 u8 products accumulated into 4 u32 lanes
669
+ nk_vu8x16_t a_u8x16 = a.vu8x16;
670
+ nk_vu8x16_t b_u8x16 = b.vu8x16;
671
+ state->sum_u32x4 = vec_msum(a_u8x16, b_u8x16, state->sum_u32x4);
672
+ }
673
+
674
+ NK_INTERNAL void nk_dot_u8x16_finalize_powervsx( //
675
+ nk_dot_u8x16_state_powervsx_t const *state_a, nk_dot_u8x16_state_powervsx_t const *state_b, //
676
+ nk_dot_u8x16_state_powervsx_t const *state_c, nk_dot_u8x16_state_powervsx_t const *state_d, //
677
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
678
+ nk_unused_(total_dimensions);
679
+ nk_vu32x4_t a_u32x4 = state_a->sum_u32x4, b_u32x4 = state_b->sum_u32x4, c_u32x4 = state_c->sum_u32x4,
680
+ d_u32x4 = state_d->sum_u32x4;
681
+ nk_vu32x4_t transpose_ab_low_u32x4 = vec_mergeh(a_u32x4, b_u32x4);
682
+ nk_vu32x4_t transpose_cd_low_u32x4 = vec_mergeh(c_u32x4, d_u32x4);
683
+ nk_vu32x4_t transpose_ab_high_u32x4 = vec_mergel(a_u32x4, b_u32x4);
684
+ nk_vu32x4_t transpose_cd_high_u32x4 = vec_mergel(c_u32x4, d_u32x4);
685
+ nk_vu32x4_t sum_lane0_u32x4 = (nk_vu32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_u32x4,
686
+ (nk_vu64x2_t)transpose_cd_low_u32x4, 0);
687
+ nk_vu32x4_t sum_lane1_u32x4 = (nk_vu32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_u32x4,
688
+ (nk_vu64x2_t)transpose_cd_low_u32x4, 3);
689
+ nk_vu32x4_t sum_lane2_u32x4 = (nk_vu32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_u32x4,
690
+ (nk_vu64x2_t)transpose_cd_high_u32x4, 0);
691
+ nk_vu32x4_t sum_lane3_u32x4 = (nk_vu32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_u32x4,
692
+ (nk_vu64x2_t)transpose_cd_high_u32x4, 3);
693
+ result->vu32x4 = vec_add(vec_add(sum_lane0_u32x4, sum_lane1_u32x4), vec_add(sum_lane2_u32x4, sum_lane3_u32x4));
694
+ }
695
+
696
+ /**
697
+ * @brief Running state for 128-bit binary dot accumulation on Power VSX.
698
+ *
699
+ * Processes 128 bits (16 bytes) at a time via AND + doubleword popcount (vpopcntd),
700
+ * accumulating bit-match counts into 2 u64 lanes.
701
+ */
702
+ typedef struct nk_dot_u1x128_state_powervsx_t {
703
+ nk_vu64x2_t dot_count_u64x2;
704
+ } nk_dot_u1x128_state_powervsx_t;
705
+
706
+ NK_INTERNAL void nk_dot_u1x128_init_powervsx(nk_dot_u1x128_state_powervsx_t *state) {
707
+ state->dot_count_u64x2 = vec_splats((nk_u64_t)0);
708
+ }
709
+
710
+ NK_INTERNAL void nk_dot_u1x128_update_powervsx(nk_dot_u1x128_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
711
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
712
+ nk_unused_(depth_offset);
713
+ nk_unused_(active_dimensions);
714
+ // AND → doubleword popcount (vpopcntd, 3cy ALU) → vec_add (7cy DP)
715
+ // Simpler data flow than vpopcntb + vec_sum4s, and u64 accumulator holds larger counts
716
+ nk_vu8x16_t a_u8x16 = a.vu8x16;
717
+ nk_vu8x16_t b_u8x16 = b.vu8x16;
718
+ nk_vu8x16_t and_u8x16 = vec_and(a_u8x16, b_u8x16);
719
+ nk_vu64x2_t popcnt_u64x2 = vec_popcnt((nk_vu64x2_t)and_u8x16);
720
+ state->dot_count_u64x2 = vec_add(state->dot_count_u64x2, popcnt_u64x2);
721
+ }
722
+
723
+ NK_INTERNAL void nk_dot_u1x128_finalize_powervsx( //
724
+ nk_dot_u1x128_state_powervsx_t const *state_a, nk_dot_u1x128_state_powervsx_t const *state_b, //
725
+ nk_dot_u1x128_state_powervsx_t const *state_c, nk_dot_u1x128_state_powervsx_t const *state_d, //
726
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
727
+ nk_unused_(total_dimensions);
728
+ nk_vu64x2_t sum_a_u64x2 = vec_add(state_a->dot_count_u64x2,
729
+ vec_xxpermdi(state_a->dot_count_u64x2, state_a->dot_count_u64x2, 2));
730
+ nk_vu64x2_t sum_b_u64x2 = vec_add(state_b->dot_count_u64x2,
731
+ vec_xxpermdi(state_b->dot_count_u64x2, state_b->dot_count_u64x2, 2));
732
+ nk_vu64x2_t sum_c_u64x2 = vec_add(state_c->dot_count_u64x2,
733
+ vec_xxpermdi(state_c->dot_count_u64x2, state_c->dot_count_u64x2, 2));
734
+ nk_vu64x2_t sum_d_u64x2 = vec_add(state_d->dot_count_u64x2,
735
+ vec_xxpermdi(state_d->dot_count_u64x2, state_d->dot_count_u64x2, 2));
736
+ nk_vu64x2_t ab_u64x2 = vec_xxpermdi(sum_a_u64x2, sum_b_u64x2, 0);
737
+ nk_vu64x2_t cd_u64x2 = vec_xxpermdi(sum_c_u64x2, sum_d_u64x2, 0);
738
+ result->vu32x4 = vec_pack(ab_u64x2, cd_u64x2);
739
+ }
740
+
741
+ #if defined(__clang__)
742
+ #pragma clang attribute pop
743
+ #elif defined(__GNUC__)
744
+ #pragma GCC pop_options
745
+ #endif
746
+
747
+ #if defined(__cplusplus)
748
+ } // extern "C"
749
+ #endif
750
+
751
+ #endif // NK_TARGET_POWERVSX
752
+ #endif // NK_DOT_POWERVSX_H