numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -0,0 +1,752 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for POWER9 VSX.
|
|
3
|
+
* @file include/numkong/dot/powervsx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_powervsx_instructions Power9 VSX Dot Product Instructions
|
|
10
|
+
*
|
|
11
|
+
* Key Power9 VSX instructions for dot products:
|
|
12
|
+
*
|
|
13
|
+
* Intrinsic Instruction POWER9
|
|
14
|
+
* vec_madd(a, b, c) XVMADDADP/XVMADDASP 5cy FMA: a×b+c
|
|
15
|
+
* vec_msub(a, b, c) XVMSUBADP/XVMSUBASP 5cy FMS: a×b−c
|
|
16
|
+
* vec_msum(a, b, c) VMSUMUBM/VMSUMMBM 5cy i8/u8 widening multiply-sum → i32/u32
|
|
17
|
+
* vec_msum(a, b, c) VMSUMSHM/VMSUMUHM 5cy i16/u16 widening multiply-sum → i32/u32
|
|
18
|
+
* vec_doublee(a) XVCVSPDP 3cy Widen even f32 lanes → f64x2
|
|
19
|
+
* vec_doubleo(a) XVCVSPDP (odd) 3cy Widen odd f32 lanes → f64x2
|
|
20
|
+
* vec_unpackh(a) VUPKHSB/VUPKHSH 2cy Sign-extend high half (i8→i16 or i16→i32)
|
|
21
|
+
* vec_unpackl(a) VUPKLSB/VUPKLSH 2cy Sign-extend low half (i8→i16 or i16→i32)
|
|
22
|
+
* vec_xor(a, b) VXOR/XXLXOR 1cy Bitwise XOR
|
|
23
|
+
* vec_xl(off, ptr) LXV 5cy Aligned 16-byte load
|
|
24
|
+
* vec_xl_len(ptr, len) LXVL 5cy Partial load (Power9), zero-fills tail
|
|
25
|
+
* vec_extract_fp32_from_shorth XVCVHPSP (high) 5cy f16x4 → f32x4 from high half
|
|
26
|
+
* vec_extract_fp32_from_shortl XVCVHPSP (low) 5cy f16x4 → f32x4 from low half
|
|
27
|
+
* vec_popcnt(a) VPOPCNTB/H/W/D 2cy Per-element popcount
|
|
28
|
+
* vec_sum4s(a, b) VSUM4UBS/VSUM4SBS 5cy Sum groups of 4 bytes → i32/u32
|
|
29
|
+
* vec_sums(a, b) VSUMSWS 5cy Signed i32x4 horizontal → i32 (lane 3)
|
|
30
|
+
*
|
|
31
|
+
* Power9 (POWER ISA 3.0) provides `vec_xl_len` for partial loads that zero-fill unused bytes,
|
|
32
|
+
* enabling branchless tail handling: zero × anything = zero, so partial vectors contribute
|
|
33
|
+
* no spurious terms to dot-product accumulators.
|
|
34
|
+
*
|
|
35
|
+
* @section dot_powervsx_stateful Stateful Streaming Logic
|
|
36
|
+
*
|
|
37
|
+
* For memory-optimal tiled algorithms, this file defines state structures and force-inlined
|
|
38
|
+
* `NK_INTERNAL` functions:
|
|
39
|
+
*
|
|
40
|
+
* - nk_dot_f32x2 state for f32 inputs with double-precision accumulation,
|
|
41
|
+
* - nk_dot_f64x2 state with Dot2 stable dot-products for f64 inputs,
|
|
42
|
+
* - nk_dot_bf16x8 state for bf16 inputs with f32 accumulation,
|
|
43
|
+
* - nk_dot_f16x8 state for f16 inputs with f32 accumulation,
|
|
44
|
+
* - nk_dot_i8x16 state for i8 inputs with i32 accumulation,
|
|
45
|
+
* - nk_dot_u8x16 state for u8 inputs with u32 accumulation,
|
|
46
|
+
* - nk_dot_u1x128 state for binary inputs with u64 popcount accumulation.
|
|
47
|
+
*/
|
|
48
|
+
#ifndef NK_DOT_POWERVSX_H
|
|
49
|
+
#define NK_DOT_POWERVSX_H
|
|
50
|
+
|
|
51
|
+
#if NK_TARGET_POWERVSX
|
|
52
|
+
|
|
53
|
+
#if defined(__cplusplus)
|
|
54
|
+
extern "C" {
|
|
55
|
+
#endif
|
|
56
|
+
|
|
57
|
+
#if defined(__clang__)
|
|
58
|
+
#pragma clang attribute push(__attribute__((target("power9-vector"))), apply_to = function)
|
|
59
|
+
#elif defined(__GNUC__)
|
|
60
|
+
#pragma GCC push_options
|
|
61
|
+
#pragma GCC target("power9-vector")
|
|
62
|
+
#endif
|
|
63
|
+
|
|
64
|
+
/** @brief Horizontal sum of 4 f32 lanes → scalar f32. */
|
|
65
|
+
NK_INTERNAL nk_f32_t nk_hsum_f32x4_powervsx_(nk_vf32x4_t values_f32x4) {
|
|
66
|
+
// Rotate by 8 bytes (2 floats) and add → {v[0]+v[2], v[1]+v[3], ...}
|
|
67
|
+
nk_vf32x4_t rotated_f32x4 = vec_sld(values_f32x4, values_f32x4, 8);
|
|
68
|
+
nk_vf32x4_t partial_f32x4 = vec_add(values_f32x4, rotated_f32x4);
|
|
69
|
+
// Rotate by 4 bytes (1 float) and add → {v[0]+v[1]+v[2]+v[3], ...}
|
|
70
|
+
nk_vf32x4_t shifted_f32x4 = vec_sld(partial_f32x4, partial_f32x4, 4);
|
|
71
|
+
nk_vf32x4_t total_f32x4 = vec_add(partial_f32x4, shifted_f32x4);
|
|
72
|
+
return vec_extract(total_f32x4, 0);
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
/** @brief Horizontal sum of 2 f64 lanes → scalar f64 via xxpermdi (1 domain crossing). */
|
|
76
|
+
NK_INTERNAL nk_f64_t nk_hsum_f64x2_powervsx_(nk_vf64x2_t values_f64x2) {
|
|
77
|
+
nk_vf64x2_t swapped_f64x2 = vec_xxpermdi(values_f64x2, values_f64x2, 2);
|
|
78
|
+
nk_vf64x2_t sum_f64x2 = vec_add(values_f64x2, swapped_f64x2);
|
|
79
|
+
return vec_extract(sum_f64x2, 0);
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
/** @brief Horizontal sum of 4 signed i32 lanes → scalar i32. */
|
|
83
|
+
NK_INTERNAL nk_i32_t nk_hsum_i32x4_powervsx_(nk_vi32x4_t values_i32x4) {
|
|
84
|
+
// vec_sums reduces i32x4 → i32 in lane 3 of the result
|
|
85
|
+
nk_vi32x4_t zero_i32x4 = vec_splats((nk_i32_t)0);
|
|
86
|
+
nk_vi32x4_t sums_i32x4 = vec_sums(values_i32x4, zero_i32x4);
|
|
87
|
+
return vec_extract(sums_i32x4, 3);
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
/** @brief Horizontal sum of 4 unsigned u32 lanes → scalar u32. */
|
|
91
|
+
NK_INTERNAL nk_u32_t nk_hsum_u32x4_powervsx_(nk_vu32x4_t values_u32x4) {
|
|
92
|
+
// Rotate by 8 bytes (2 ints) and add → {v[0]+v[2], v[1]+v[3], ...}
|
|
93
|
+
nk_vu32x4_t rotated_u32x4 = vec_sld(values_u32x4, values_u32x4, 8);
|
|
94
|
+
nk_vu32x4_t partial_u32x4 = vec_add(values_u32x4, rotated_u32x4);
|
|
95
|
+
// Rotate by 4 bytes (1 int) and add → {v[0]+v[1]+v[2]+v[3], ...}
|
|
96
|
+
nk_vu32x4_t shifted_u32x4 = vec_sld(partial_u32x4, partial_u32x4, 4);
|
|
97
|
+
nk_vu32x4_t total_u32x4 = vec_add(partial_u32x4, shifted_u32x4);
|
|
98
|
+
return vec_extract(total_u32x4, 0);
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
/** @brief Horizontal sum of 2 unsigned u64 lanes → scalar u64 via xxpermdi. */
|
|
102
|
+
NK_INTERNAL nk_u64_t nk_hsum_u64x2_powervsx_(nk_vu64x2_t values_u64x2) {
|
|
103
|
+
nk_vu64x2_t swapped_u64x2 = vec_xxpermdi(values_u64x2, values_u64x2, 2);
|
|
104
|
+
nk_vu64x2_t sum_u64x2 = vec_add(values_u64x2, swapped_u64x2);
|
|
105
|
+
return vec_extract(sum_u64x2, 0);
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
/** @brief Compensated horizontal sum of 2 f64 lanes via TwoSum. */
|
|
109
|
+
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x2_powervsx_(nk_vf64x2_t sum_f64x2, nk_vf64x2_t compensation_f64x2) {
|
|
110
|
+
// TwoSum merge of sum + compensation (2-wide)
|
|
111
|
+
nk_vf64x2_t tentative_sum_f64x2 = vec_add(sum_f64x2, compensation_f64x2);
|
|
112
|
+
nk_vf64x2_t virtual_addend_f64x2 = vec_sub(tentative_sum_f64x2, sum_f64x2);
|
|
113
|
+
nk_vf64x2_t rounding_error_f64x2 = vec_add(vec_sub(sum_f64x2, vec_sub(tentative_sum_f64x2, virtual_addend_f64x2)),
|
|
114
|
+
vec_sub(compensation_f64x2, virtual_addend_f64x2));
|
|
115
|
+
// Scalar TwoSum 2 → 1
|
|
116
|
+
nk_f64_t lower_sum = vec_extract(tentative_sum_f64x2, 0);
|
|
117
|
+
nk_f64_t upper_sum = vec_extract(tentative_sum_f64x2, 1);
|
|
118
|
+
nk_f64_t lower_error = vec_extract(rounding_error_f64x2, 0);
|
|
119
|
+
nk_f64_t upper_error = vec_extract(rounding_error_f64x2, 1);
|
|
120
|
+
nk_f64_t tentative_sum = lower_sum + upper_sum;
|
|
121
|
+
nk_f64_t virtual_addend = tentative_sum - lower_sum;
|
|
122
|
+
nk_f64_t rounding_error = (lower_sum - (tentative_sum - virtual_addend)) + (upper_sum - virtual_addend);
|
|
123
|
+
return tentative_sum + (lower_error + upper_error + rounding_error);
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
#pragma region F32 and F64 Floats
|
|
127
|
+
|
|
128
|
+
NK_PUBLIC void nk_dot_f32_powervsx(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
129
|
+
nk_f64_t *result) {
|
|
130
|
+
// Upcast f32 → f64 for accumulation via vec_doublee (even lanes) and vec_doubleo (odd lanes)
|
|
131
|
+
nk_vf64x2_t sum_even_f64x2 = vec_splats((nk_f64_t)0);
|
|
132
|
+
nk_vf64x2_t sum_odd_f64x2 = vec_splats((nk_f64_t)0);
|
|
133
|
+
nk_vf32x4_t a_f32x4, b_f32x4;
|
|
134
|
+
nk_size_t tail_bytes;
|
|
135
|
+
|
|
136
|
+
nk_dot_f32_powervsx_cycle:
|
|
137
|
+
if (count_scalars < 4) {
|
|
138
|
+
tail_bytes = count_scalars * sizeof(nk_f32_t);
|
|
139
|
+
a_f32x4 = vec_xl_len((nk_f32_t *)a_scalars, tail_bytes);
|
|
140
|
+
b_f32x4 = vec_xl_len((nk_f32_t *)b_scalars, tail_bytes);
|
|
141
|
+
count_scalars = 0;
|
|
142
|
+
}
|
|
143
|
+
else {
|
|
144
|
+
a_f32x4 = vec_xl(0, a_scalars);
|
|
145
|
+
b_f32x4 = vec_xl(0, b_scalars);
|
|
146
|
+
a_scalars += 4, b_scalars += 4, count_scalars -= 4;
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
// Widen even/odd f32 lanes → f64x2, then FMA
|
|
150
|
+
nk_vf64x2_t a_even_f64x2 = vec_doublee(a_f32x4);
|
|
151
|
+
nk_vf64x2_t b_even_f64x2 = vec_doublee(b_f32x4);
|
|
152
|
+
nk_vf64x2_t a_odd_f64x2 = vec_doubleo(a_f32x4);
|
|
153
|
+
nk_vf64x2_t b_odd_f64x2 = vec_doubleo(b_f32x4);
|
|
154
|
+
sum_even_f64x2 = vec_madd(a_even_f64x2, b_even_f64x2, sum_even_f64x2);
|
|
155
|
+
sum_odd_f64x2 = vec_madd(a_odd_f64x2, b_odd_f64x2, sum_odd_f64x2);
|
|
156
|
+
|
|
157
|
+
if (count_scalars) goto nk_dot_f32_powervsx_cycle;
|
|
158
|
+
// Combine even and odd accumulators → final scalar
|
|
159
|
+
nk_vf64x2_t total_f64x2 = vec_add(sum_even_f64x2, sum_odd_f64x2);
|
|
160
|
+
*result = nk_hsum_f64x2_powervsx_(total_f64x2);
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
NK_PUBLIC void nk_dot_f64_powervsx(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
                                   nk_f64_t *result) {
    // Compensated Dot2 (Ogita-Rump-Oishi 2005): every product's and every sum's
    // rounding error is captured exactly and deferred into a compensation vector.
    nk_vf64x2_t running_sum_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t compensation_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t a_f64x2, b_f64x2;

    do {
        if (count_scalars < 2) {
            // Zero-filled partial load: 0 × anything = 0 with zero error term.
            nk_size_t remaining_bytes = count_scalars * sizeof(nk_f64_t);
            a_f64x2 = vec_xl_len((nk_f64_t *)a_scalars, remaining_bytes);
            b_f64x2 = vec_xl_len((nk_f64_t *)b_scalars, remaining_bytes);
            count_scalars = 0;
        }
        else {
            a_f64x2 = vec_xl(0, a_scalars);
            b_f64x2 = vec_xl(0, b_scalars);
            a_scalars += 2, b_scalars += 2, count_scalars -= 2;
        }
        // TwoProd: the FMS recovers the product's rounding error exactly.
        nk_vf64x2_t product_f64x2 = vec_mul(a_f64x2, b_f64x2);
        nk_vf64x2_t product_error_f64x2 = vec_msub(a_f64x2, b_f64x2, product_f64x2);
        // TwoSum: tentative = fl(sum + product), plus its exact rounding error.
        nk_vf64x2_t tentative_f64x2 = vec_add(running_sum_f64x2, product_f64x2);
        nk_vf64x2_t virtual_f64x2 = vec_sub(tentative_f64x2, running_sum_f64x2);
        nk_vf64x2_t sum_error_f64x2 = vec_add(vec_sub(running_sum_f64x2, vec_sub(tentative_f64x2, virtual_f64x2)),
                                              vec_sub(product_f64x2, virtual_f64x2));
        running_sum_f64x2 = tentative_f64x2;
        compensation_f64x2 = vec_add(compensation_f64x2, vec_add(sum_error_f64x2, product_error_f64x2));
    } while (count_scalars);

    // Reduce with a compensated merge so the deferred errors are not lost.
    *result = nk_dot_stable_sum_f64x2_powervsx_(running_sum_f64x2, compensation_f64x2);
}
|
|
200
|
+
|
|
201
|
+
#pragma endregion F32 and F64 Floats
|
|
202
|
+
#pragma region F16 and BF16 Floats
|
|
203
|
+
|
|
204
|
+
NK_PUBLIC void nk_dot_bf16_powervsx(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
205
|
+
nk_f32_t *result) {
|
|
206
|
+
// bf16 → f32 via mergeh/mergel with zero: shift 16 bits into f32 upper half
|
|
207
|
+
nk_vu16x8_t zero_u16x8 = vec_splats((nk_u16_t)0);
|
|
208
|
+
nk_vf32x4_t sum_f32x4 = vec_splats((nk_f32_t)0);
|
|
209
|
+
nk_vu16x8_t a_u16x8, b_u16x8;
|
|
210
|
+
nk_size_t tail_bytes;
|
|
211
|
+
|
|
212
|
+
nk_dot_bf16_powervsx_cycle:
|
|
213
|
+
if (count_scalars < 8) {
|
|
214
|
+
tail_bytes = count_scalars * sizeof(nk_bf16_t);
|
|
215
|
+
a_u16x8 = vec_xl_len((nk_u16_t *)a_scalars, tail_bytes);
|
|
216
|
+
b_u16x8 = vec_xl_len((nk_u16_t *)b_scalars, tail_bytes);
|
|
217
|
+
count_scalars = 0;
|
|
218
|
+
}
|
|
219
|
+
else {
|
|
220
|
+
a_u16x8 = vec_xl(0, (nk_u16_t const *)a_scalars);
|
|
221
|
+
b_u16x8 = vec_xl(0, (nk_u16_t const *)b_scalars);
|
|
222
|
+
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
// Convert bf16 → f32: merge with zero puts bf16 bits in upper 16 of each f32
|
|
226
|
+
nk_vf32x4_t a_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, a_u16x8);
|
|
227
|
+
nk_vf32x4_t a_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, a_u16x8);
|
|
228
|
+
nk_vf32x4_t b_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, b_u16x8);
|
|
229
|
+
nk_vf32x4_t b_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, b_u16x8);
|
|
230
|
+
sum_f32x4 = vec_madd(a_high_f32x4, b_high_f32x4, sum_f32x4);
|
|
231
|
+
sum_f32x4 = vec_madd(a_low_f32x4, b_low_f32x4, sum_f32x4);
|
|
232
|
+
|
|
233
|
+
if (count_scalars) goto nk_dot_bf16_powervsx_cycle;
|
|
234
|
+
*result = nk_hsum_f32x4_powervsx_(sum_f32x4);
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
NK_PUBLIC void nk_dot_f16_powervsx(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
238
|
+
nk_f32_t *result) {
|
|
239
|
+
// f16 → f32 via vec_extract_fp32_from_shorth/shortl (Power9 XVCVHPSP)
|
|
240
|
+
nk_vf32x4_t sum_f32x4 = vec_splats((nk_f32_t)0);
|
|
241
|
+
nk_vu16x8_t a_u16x8, b_u16x8;
|
|
242
|
+
nk_size_t tail_bytes;
|
|
243
|
+
|
|
244
|
+
nk_dot_f16_powervsx_cycle:
|
|
245
|
+
if (count_scalars < 8) {
|
|
246
|
+
tail_bytes = count_scalars * sizeof(nk_f16_t);
|
|
247
|
+
a_u16x8 = vec_xl_len((nk_u16_t *)a_scalars, tail_bytes);
|
|
248
|
+
b_u16x8 = vec_xl_len((nk_u16_t *)b_scalars, tail_bytes);
|
|
249
|
+
count_scalars = 0;
|
|
250
|
+
}
|
|
251
|
+
else {
|
|
252
|
+
a_u16x8 = vec_xl(0, (nk_u16_t const *)a_scalars);
|
|
253
|
+
b_u16x8 = vec_xl(0, (nk_u16_t const *)b_scalars);
|
|
254
|
+
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
// Convert f16 → f32 via hardware XVCVHPSP
|
|
258
|
+
nk_vf32x4_t a_high_f32x4 = vec_extract_fp32_from_shorth(a_u16x8);
|
|
259
|
+
nk_vf32x4_t a_low_f32x4 = vec_extract_fp32_from_shortl(a_u16x8);
|
|
260
|
+
nk_vf32x4_t b_high_f32x4 = vec_extract_fp32_from_shorth(b_u16x8);
|
|
261
|
+
nk_vf32x4_t b_low_f32x4 = vec_extract_fp32_from_shortl(b_u16x8);
|
|
262
|
+
sum_f32x4 = vec_madd(a_high_f32x4, b_high_f32x4, sum_f32x4);
|
|
263
|
+
sum_f32x4 = vec_madd(a_low_f32x4, b_low_f32x4, sum_f32x4);
|
|
264
|
+
|
|
265
|
+
if (count_scalars) goto nk_dot_f16_powervsx_cycle;
|
|
266
|
+
*result = nk_hsum_f32x4_powervsx_(sum_f32x4);
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
#pragma endregion F16 and BF16 Floats
|
|
270
|
+
#pragma region I8 and U8 Integers
|
|
271
|
+
|
|
272
|
+
/**
 *  @brief Dot product of two signed 8-bit vectors, accumulated into a signed 32-bit result.
 *
 *  Power VSX has no native i8×i8 multiply-sum, only VMSUMMBM (i8×u8 → i32), so the
 *  signed×signed product is recovered algebraically from a biased unsigned operand.
 */
NK_PUBLIC void nk_dot_i8_powervsx(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
                                  nk_i32_t *result) {
    // Algebraic transform for i8×i8 using VMSUMMBM (i8×u8 → i32):
    // b' = b ⊕ 0x80 (reinterpret signed as unsigned: b' = b + 128 as u8)
    // a·b = a·b' − 128·Σa
    // Σ(a+128) accumulated via VSUM4UBS; correction applied after loop.
    // Tail handling is free: vec_xl_len zero-fills unused lanes.
    // - Product: 0 × (0⊕0x80) = 0 → no spurious contribution
    // - Correction: (0⊕0x80) = 128 in sum_a_biased, compensated by count_padded
    nk_vu8x16_t const bias_u8x16 = vec_splats((nk_u8_t)0x80);
    nk_vi32x4_t accumulator_i32x4 = vec_splats((nk_i32_t)0);
    nk_vu32x4_t sum_a_biased_u32x4 = vec_splats((nk_u32_t)0);
    // Round up to whole 16-lane steps: zero-filled tail lanes still contribute +128 each
    // to sum_a_biased, so the correction below must count them too.
    nk_size_t count_padded = ((count_scalars + 15) / 16) * 16;
    nk_vi8x16_t a_i8x16;
    nk_vu8x16_t b_biased_u8x16;
    nk_size_t tail_bytes;

nk_dot_i8_powervsx_cycle:
    if (count_scalars < 16) {
        // Partial step: length-bounded load zero-fills the unused upper lanes.
        tail_bytes = count_scalars * sizeof(nk_i8_t);
        a_i8x16 = vec_xl_len((nk_i8_t *)a_scalars, tail_bytes);
        b_biased_u8x16 = vec_xor(vec_xl_len((nk_u8_t *)b_scalars, tail_bytes), bias_u8x16);
        count_scalars = 0;
    }
    else {
        a_i8x16 = vec_xl(0, a_scalars);
        b_biased_u8x16 = vec_xor(vec_xl(0, (nk_u8_t *)b_scalars), bias_u8x16);
        a_scalars += 16, b_scalars += 16, count_scalars -= 16;
    }

    // VMSUMMBM: i8 × u8 → i32 (16 products per instruction)
    accumulator_i32x4 = vec_msum(a_i8x16, b_biased_u8x16, accumulator_i32x4);
    // VSUM4UBS: accumulate Σ(a+128) as unsigned (independent chain, good ILP)
    sum_a_biased_u32x4 = vec_sum4s(vec_xor((nk_vu8x16_t)a_i8x16, bias_u8x16), sum_a_biased_u32x4);

    if (count_scalars) goto nk_dot_i8_powervsx_cycle;

    // Correction: a·b = biased_dot − 128·Σa = biased_dot − 128·(Σ(a+128) − 128·count_padded)
    // i.e. subtract 128·sum_a_biased and add back 16384·count_padded; done in i64 to
    // avoid intermediate overflow before the final narrowing store.
    nk_i32_t biased_dot = nk_hsum_i32x4_powervsx_(accumulator_i32x4);
    nk_i64_t correction = 128LL * (nk_i64_t)nk_hsum_u32x4_powervsx_(sum_a_biased_u32x4) -
                          16384LL * (nk_i64_t)count_padded;
    *result = (nk_i32_t)((nk_i64_t)biased_dot - correction);
}
|
|
315
|
+
|
|
316
|
+
NK_PUBLIC void nk_dot_u8_powervsx(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
|
|
317
|
+
nk_u32_t *result) {
|
|
318
|
+
// vec_msum: multiply u8×u8 pairs and accumulate 16 products → 4 u32 lanes per call
|
|
319
|
+
nk_vu32x4_t accumulator_u32x4 = vec_splats((nk_u32_t)0);
|
|
320
|
+
nk_vu8x16_t a_u8x16, b_u8x16;
|
|
321
|
+
nk_size_t tail_bytes;
|
|
322
|
+
|
|
323
|
+
nk_dot_u8_powervsx_cycle:
|
|
324
|
+
if (count_scalars < 16) {
|
|
325
|
+
tail_bytes = count_scalars * sizeof(nk_u8_t);
|
|
326
|
+
a_u8x16 = vec_xl_len((nk_u8_t *)a_scalars, tail_bytes);
|
|
327
|
+
b_u8x16 = vec_xl_len((nk_u8_t *)b_scalars, tail_bytes);
|
|
328
|
+
count_scalars = 0;
|
|
329
|
+
}
|
|
330
|
+
else {
|
|
331
|
+
a_u8x16 = vec_xl(0, a_scalars);
|
|
332
|
+
b_u8x16 = vec_xl(0, b_scalars);
|
|
333
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
// Unsigned × unsigned multiply-sum: 16 u8 products accumulated into 4 u32 lanes
|
|
337
|
+
accumulator_u32x4 = vec_msum(a_u8x16, b_u8x16, accumulator_u32x4);
|
|
338
|
+
|
|
339
|
+
if (count_scalars) goto nk_dot_u8_powervsx_cycle;
|
|
340
|
+
*result = nk_hsum_u32x4_powervsx_(accumulator_u32x4);
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
#pragma endregion I8 and U8 Integers
|
|
344
|
+
#pragma region Binary
|
|
345
|
+
|
|
346
|
+
/**
 *  @brief Binary dot product: counts the number of positions where both bit-vectors have a 1.
 *
 *  Implemented as AND followed by doubleword population count, 128 bits per step.
 */
NK_PUBLIC void nk_dot_u1_powervsx(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
    // Bits are packed 8 per byte; a partial trailing byte is rounded up.
    // NOTE(review): assumes any padding bits in the last byte are zero — confirm with callers.
    nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
    nk_vu64x2_t accumulator_u64x2 = vec_splats((nk_u64_t)0);
    nk_vu8x16_t a_u8x16, b_u8x16;

nk_dot_u1_powervsx_cycle:
    if (n_bytes < 16) {
        // Partial step: length-bounded load zero-fills unused lanes → they add 0 matches.
        a_u8x16 = vec_xl_len((nk_u8_t *)a, n_bytes);
        b_u8x16 = vec_xl_len((nk_u8_t *)b, n_bytes);
        n_bytes = 0;
    }
    else {
        a_u8x16 = vec_xl(0, (nk_u8_t const *)a);
        b_u8x16 = vec_xl(0, (nk_u8_t const *)b);
        a += 16, b += 16, n_bytes -= 16;
    }

    // AND → doubleword popcount (vpopcntd) → accumulate u64 lanes
    nk_vu8x16_t and_u8x16 = vec_and(a_u8x16, b_u8x16);
    nk_vu64x2_t popcnt_u64x2 = vec_popcnt((nk_vu64x2_t)and_u8x16);
    accumulator_u64x2 = vec_add(accumulator_u64x2, popcnt_u64x2);

    if (n_bytes) goto nk_dot_u1_powervsx_cycle;
    // Horizontal sum of the two u64 lanes, narrowed to the u32 result type.
    *result = (nk_u32_t)nk_hsum_u64x2_powervsx_(accumulator_u64x2);
}
|
|
371
|
+
|
|
372
|
+
#pragma endregion Binary
|
|
373
|
+
|
|
374
|
+
/**
 * @brief Running state for 128-bit dot accumulation over f32 scalars on Power VSX.
 *
 * Processes 2 f32 values at a time, upcasting to f64 for accumulation to avoid
 * catastrophic cancellation in long reductions.
 */
typedef struct nk_dot_f32x2_state_powervsx_t {
    nk_vf64x2_t sum_f64x2; // two f64 partial sums; combined lane-wise at finalize
} nk_dot_f32x2_state_powervsx_t;

/** @brief Zeroes the running f64 partial sums. */
NK_INTERNAL void nk_dot_f32x2_init_powervsx(nk_dot_f32x2_state_powervsx_t *state) {
    state->sum_f64x2 = vec_splats((nk_f64_t)0);
}

/**
 * @brief Multiply-accumulates one pair of 2-element f32 slices into the f64 state.
 *
 * `depth_offset` / `active_dimensions` are part of the shared update signature and unused here.
 */
NK_INTERNAL void nk_dot_f32x2_update_powervsx(nk_dot_f32x2_state_powervsx_t *state, nk_b64_vec_t a, nk_b64_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // Load 8 bytes (2 f32s) into a vector register, zero-filling the upper 8 bytes
    nk_vf32x4_t a_f32x4 = vec_xl_len((nk_f32_t *)a.f32s, 8);
    nk_vf32x4_t b_f32x4 = vec_xl_len((nk_f32_t *)b.f32s, 8);
    // Widen even lanes (the two f32 values) → f64x2
    // NOTE(review): vec_doublee/vec_doubleo lane selection is endian-sensitive;
    // this layout presumably targets little-endian POWER — confirm on big-endian builds.
    nk_vf64x2_t a_f64x2 = vec_doublee(a_f32x4);
    nk_vf64x2_t b_f64x2 = vec_doublee(b_f32x4);
    // Permute to get {lane0, lane2} → {a[0], a[1]} as f64x2
    a_f64x2 = vec_xxpermdi(a_f64x2, vec_doubleo(a_f32x4), 0);
    b_f64x2 = vec_xxpermdi(b_f64x2, vec_doubleo(b_f32x4), 0);
    // Fused multiply-add in f64 keeps the long reduction numerically stable.
    state->sum_f64x2 = vec_madd(a_f64x2, b_f64x2, state->sum_f64x2);
}

/**
 * @brief Reduces four independent accumulators into four f64 results.
 *
 * Each state's two f64 lanes are summed by adding the vector to its doubleword-swapped
 * copy, then pairs of scalar totals are packed into the 256-bit result.
 */
NK_INTERNAL void nk_dot_f32x2_finalize_powervsx( //
    nk_dot_f32x2_state_powervsx_t const *state_a, nk_dot_f32x2_state_powervsx_t const *state_b, //
    nk_dot_f32x2_state_powervsx_t const *state_c, nk_dot_f32x2_state_powervsx_t const *state_d, //
    nk_size_t total_dimensions, nk_b256_vec_t *result) {
    nk_unused_(total_dimensions);
    // xxpermdi(x, x, 2) swaps the two doublewords, so each lane of the sum holds the total.
    nk_vf64x2_t sum_a_f64x2 = vec_add(state_a->sum_f64x2, vec_xxpermdi(state_a->sum_f64x2, state_a->sum_f64x2, 2));
    nk_vf64x2_t sum_b_f64x2 = vec_add(state_b->sum_f64x2, vec_xxpermdi(state_b->sum_f64x2, state_b->sum_f64x2, 2));
    nk_vf64x2_t sum_c_f64x2 = vec_add(state_c->sum_f64x2, vec_xxpermdi(state_c->sum_f64x2, state_c->sum_f64x2, 2));
    nk_vf64x2_t sum_d_f64x2 = vec_add(state_d->sum_f64x2, vec_xxpermdi(state_d->sum_f64x2, state_d->sum_f64x2, 2));
    result->vf64x2s[0] = vec_xxpermdi(sum_a_f64x2, sum_b_f64x2, 0);
    result->vf64x2s[1] = vec_xxpermdi(sum_c_f64x2, sum_d_f64x2, 0);
}
|
|
416
|
+
|
|
417
|
+
/**
 * @brief Running state for 128-bit dot accumulation over f64 scalars on Power VSX.
 *
 * Uses the Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated dot product:
 * the exact rounding errors of each product and each addition are tracked in a
 * separate compensation term, roughly doubling the effective precision.
 */
typedef struct nk_dot_f64x2_state_powervsx_t {
    nk_vf64x2_t sum_f64x2;          // leading (rounded) running sum
    nk_vf64x2_t compensation_f64x2; // accumulated rounding errors of products and sums
} nk_dot_f64x2_state_powervsx_t;

/** @brief Zeroes both the sum and its compensation term. */
NK_INTERNAL void nk_dot_f64x2_init_powervsx(nk_dot_f64x2_state_powervsx_t *state) {
    state->sum_f64x2 = vec_splats((nk_f64_t)0);
    state->compensation_f64x2 = vec_splats((nk_f64_t)0);
}

/**
 * @brief One compensated multiply-accumulate step over 2 f64 lanes.
 *
 * The exact statement order below is load-bearing: TwoProd and TwoSum rely on
 * IEEE round-to-nearest of each individual operation, so none of these
 * expressions may be algebraically "simplified".
 */
NK_INTERNAL void nk_dot_f64x2_update_powervsx(nk_dot_f64x2_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    nk_vf64x2_t sum_f64x2 = state->sum_f64x2;
    nk_vf64x2_t compensation_f64x2 = state->compensation_f64x2;
    nk_vf64x2_t a_f64x2 = a.vf64x2;
    nk_vf64x2_t b_f64x2 = b.vf64x2;

    // TwoProd: product = a × b, error = msub(a, b, product) captures rounding error
    // (the FMA computes a·b − product exactly, i.e. the bits lost by vec_mul).
    nk_vf64x2_t product_f64x2 = vec_mul(a_f64x2, b_f64x2);
    nk_vf64x2_t product_error_f64x2 = vec_msub(a_f64x2, b_f64x2, product_f64x2);

    // TwoSum: (t, q) = TwoSum(sum, product) where t = sum + product rounded, q = error
    nk_vf64x2_t tentative_sum_f64x2 = vec_add(sum_f64x2, product_f64x2);
    nk_vf64x2_t virtual_addend_f64x2 = vec_sub(tentative_sum_f64x2, sum_f64x2);
    nk_vf64x2_t sum_error_f64x2 = vec_add(vec_sub(sum_f64x2, vec_sub(tentative_sum_f64x2, virtual_addend_f64x2)),
                                          vec_sub(product_f64x2, virtual_addend_f64x2));

    // Update: sum = t, compensation += q + r
    state->sum_f64x2 = tentative_sum_f64x2;
    state->compensation_f64x2 = vec_add(compensation_f64x2, vec_add(sum_error_f64x2, product_error_f64x2));
}

/**
 * @brief Reduces four compensated accumulators into four f64 results.
 *
 * Delegates per-state reduction to the shared helper so the compensation term
 * is folded in without losing the Dot2 accuracy guarantee.
 */
NK_INTERNAL void nk_dot_f64x2_finalize_powervsx( //
    nk_dot_f64x2_state_powervsx_t const *state_a, nk_dot_f64x2_state_powervsx_t const *state_b, //
    nk_dot_f64x2_state_powervsx_t const *state_c, nk_dot_f64x2_state_powervsx_t const *state_d, //
    nk_size_t total_dimensions, nk_b256_vec_t *result) {
    nk_unused_(total_dimensions);
    // Compensated horizontal reduction preserving Dot2 error tracking per state
    result->f64s[0] = nk_dot_stable_sum_f64x2_powervsx_(state_a->sum_f64x2, state_a->compensation_f64x2);
    result->f64s[1] = nk_dot_stable_sum_f64x2_powervsx_(state_b->sum_f64x2, state_b->compensation_f64x2);
    result->f64s[2] = nk_dot_stable_sum_f64x2_powervsx_(state_c->sum_f64x2, state_c->compensation_f64x2);
    result->f64s[3] = nk_dot_stable_sum_f64x2_powervsx_(state_d->sum_f64x2, state_d->compensation_f64x2);
}
|
|
467
|
+
|
|
468
|
+
/**
 * @brief Running state for 128-bit dot accumulation over bf16 scalars on Power VSX.
 *
 * Processes 8 bf16 values at a time (128 bits), converting to f32 via vec_mergeh/mergel
 * with zero for accumulation.
 */
typedef struct nk_dot_bf16x8_state_powervsx_t {
    nk_vf32x4_t sum_f32x4; // 4 f32 partial sums; transpose-reduced at finalize
} nk_dot_bf16x8_state_powervsx_t;

/** @brief Zeroes the running f32 partial sums. */
NK_INTERNAL void nk_dot_bf16x8_init_powervsx(nk_dot_bf16x8_state_powervsx_t *state) {
    state->sum_f32x4 = vec_splats((nk_f32_t)0);
}

/**
 * @brief Multiply-accumulates one pair of 8-element bf16 slices into the f32 state.
 *
 * bf16 is the upper half of an f32 bit pattern, so interleaving each u16 with a zero
 * u16 reconstructs the exact f32 value — no arithmetic conversion needed.
 */
NK_INTERNAL void nk_dot_bf16x8_update_powervsx(nk_dot_bf16x8_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                               nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // Convert bf16 → f32 inline: merge with zero puts bf16 bits in upper 16 of each f32
    // NOTE(review): the (zero, value) merge order is endian-sensitive — presumably
    // laid out for the target's element order; confirm on big-endian builds.
    nk_vu16x8_t zero_u16x8 = vec_splats((nk_u16_t)0);
    nk_vu16x8_t a_u16x8 = a.vu16x8;
    nk_vu16x8_t b_u16x8 = b.vu16x8;
    nk_vf32x4_t a_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, a_u16x8);
    nk_vf32x4_t a_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, a_u16x8);
    nk_vf32x4_t b_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, b_u16x8);
    nk_vf32x4_t b_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, b_u16x8);
    // Two FMAs fold all 8 products into the 4-lane accumulator.
    state->sum_f32x4 = vec_madd(a_high_f32x4, b_high_f32x4, state->sum_f32x4);
    state->sum_f32x4 = vec_madd(a_low_f32x4, b_low_f32x4, state->sum_f32x4);
}

/**
 * @brief Reduces four independent 4-lane accumulators into four f32 results.
 *
 * Performs a 4×4 transpose with vec_mergeh/mergel + vec_xxpermdi, then adds the four
 * transposed rows so lane i of the result holds the total for state i.
 */
NK_INTERNAL void nk_dot_bf16x8_finalize_powervsx( //
    nk_dot_bf16x8_state_powervsx_t const *state_a, nk_dot_bf16x8_state_powervsx_t const *state_b, //
    nk_dot_bf16x8_state_powervsx_t const *state_c, nk_dot_bf16x8_state_powervsx_t const *state_d, //
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);
    nk_vf32x4_t a_f32x4 = state_a->sum_f32x4, b_f32x4 = state_b->sum_f32x4, c_f32x4 = state_c->sum_f32x4,
                d_f32x4 = state_d->sum_f32x4;
    // Interleave pairs of accumulators (first half of the 4×4 transpose)...
    nk_vf32x4_t transpose_ab_low_f32x4 = vec_mergeh(a_f32x4, b_f32x4);
    nk_vf32x4_t transpose_cd_low_f32x4 = vec_mergeh(c_f32x4, d_f32x4);
    nk_vf32x4_t transpose_ab_high_f32x4 = vec_mergel(a_f32x4, b_f32x4);
    nk_vf32x4_t transpose_cd_high_f32x4 = vec_mergel(c_f32x4, d_f32x4);
    // ...then combine doubleword halves to complete the transpose (immediates 0 and 3
    // select the matching doubleword of each operand).
    nk_vf32x4_t sum_lane0_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_f32x4,
                                                            (nk_vu64x2_t)transpose_cd_low_f32x4, 0);
    nk_vf32x4_t sum_lane1_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_f32x4,
                                                            (nk_vu64x2_t)transpose_cd_low_f32x4, 3);
    nk_vf32x4_t sum_lane2_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_f32x4,
                                                            (nk_vu64x2_t)transpose_cd_high_f32x4, 0);
    nk_vf32x4_t sum_lane3_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_f32x4,
                                                            (nk_vu64x2_t)transpose_cd_high_f32x4, 3);
    result->vf32x4 = vec_add(vec_add(sum_lane0_f32x4, sum_lane1_f32x4), vec_add(sum_lane2_f32x4, sum_lane3_f32x4));
}
|
|
519
|
+
|
|
520
|
+
/**
 * @brief Running state for 128-bit dot accumulation over f16 scalars on Power VSX.
 *
 * Processes 8 f16 values at a time (128 bits), converting to f32 via
 * vec_extract_fp32_from_shorth/shortl for accumulation.
 */
typedef struct nk_dot_f16x8_state_powervsx_t {
    nk_vf32x4_t sum_f32x4; // 4 f32 partial sums; transpose-reduced at finalize
} nk_dot_f16x8_state_powervsx_t;

/** @brief Zeroes the running f32 partial sums. */
NK_INTERNAL void nk_dot_f16x8_init_powervsx(nk_dot_f16x8_state_powervsx_t *state) {
    state->sum_f32x4 = vec_splats((nk_f32_t)0);
}

/**
 * @brief Multiply-accumulates one pair of 8-element f16 slices into the f32 state.
 *
 * Uses the hardware half→single conversion rather than a bit-manipulation emulation.
 */
NK_INTERNAL void nk_dot_f16x8_update_powervsx(nk_dot_f16x8_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // Convert f16 → f32 via hardware XVCVHPSP: `shorth` takes one half of the u16
    // lanes, `shortl` the other, yielding two f32x4 vectors per input.
    nk_vu16x8_t a_u16x8 = a.vu16x8;
    nk_vu16x8_t b_u16x8 = b.vu16x8;
    nk_vf32x4_t a_high_f32x4 = vec_extract_fp32_from_shorth(a_u16x8);
    nk_vf32x4_t a_low_f32x4 = vec_extract_fp32_from_shortl(a_u16x8);
    nk_vf32x4_t b_high_f32x4 = vec_extract_fp32_from_shorth(b_u16x8);
    nk_vf32x4_t b_low_f32x4 = vec_extract_fp32_from_shortl(b_u16x8);
    // Two FMAs fold all 8 products into the 4-lane accumulator.
    state->sum_f32x4 = vec_madd(a_high_f32x4, b_high_f32x4, state->sum_f32x4);
    state->sum_f32x4 = vec_madd(a_low_f32x4, b_low_f32x4, state->sum_f32x4);
}

/**
 * @brief Reduces four independent 4-lane accumulators into four f32 results.
 *
 * 4×4 transpose via vec_mergeh/mergel + vec_xxpermdi, then the four transposed rows
 * are added so lane i of the result holds the total for state i.
 */
NK_INTERNAL void nk_dot_f16x8_finalize_powervsx( //
    nk_dot_f16x8_state_powervsx_t const *state_a, nk_dot_f16x8_state_powervsx_t const *state_b, //
    nk_dot_f16x8_state_powervsx_t const *state_c, nk_dot_f16x8_state_powervsx_t const *state_d, //
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);
    nk_vf32x4_t a_f32x4 = state_a->sum_f32x4, b_f32x4 = state_b->sum_f32x4, c_f32x4 = state_c->sum_f32x4,
                d_f32x4 = state_d->sum_f32x4;
    nk_vf32x4_t transpose_ab_low_f32x4 = vec_mergeh(a_f32x4, b_f32x4);
    nk_vf32x4_t transpose_cd_low_f32x4 = vec_mergeh(c_f32x4, d_f32x4);
    nk_vf32x4_t transpose_ab_high_f32x4 = vec_mergel(a_f32x4, b_f32x4);
    nk_vf32x4_t transpose_cd_high_f32x4 = vec_mergel(c_f32x4, d_f32x4);
    // Immediates 0 and 3 select the matching doubleword of each operand,
    // completing the transpose started by the merges above.
    nk_vf32x4_t sum_lane0_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_f32x4,
                                                            (nk_vu64x2_t)transpose_cd_low_f32x4, 0);
    nk_vf32x4_t sum_lane1_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_f32x4,
                                                            (nk_vu64x2_t)transpose_cd_low_f32x4, 3);
    nk_vf32x4_t sum_lane2_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_f32x4,
                                                            (nk_vu64x2_t)transpose_cd_high_f32x4, 0);
    nk_vf32x4_t sum_lane3_f32x4 = (nk_vf32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_f32x4,
                                                            (nk_vu64x2_t)transpose_cd_high_f32x4, 3);
    result->vf32x4 = vec_add(vec_add(sum_lane0_f32x4, sum_lane1_f32x4), vec_add(sum_lane2_f32x4, sum_lane3_f32x4));
}
|
|
570
|
+
|
|
571
|
+
/**
 * @brief Running state for 128-bit dot accumulation over i8 scalars on Power VSX.
 *
 * Algebraic transform: a·b = a·(b⊕0x80) − 128·Σa. Uses VMSUMMBM (i8×u8 → i32) for the biased
 * product. Correction is applied at finalize using precomputed column sums from the compensated
 * macro infrastructure.
 */
typedef struct nk_dot_i8x16_state_powervsx_t {
    nk_vi32x4_t biased_sum_i32x4; // Σ(b_i · (a_i + 128)) per 4-lane group; unbiased at finalize
} nk_dot_i8x16_state_powervsx_t;

/** @brief Zeroes the biased product accumulator. */
NK_INTERNAL void nk_dot_i8x16_init_powervsx(nk_dot_i8x16_state_powervsx_t *state) {
    state->biased_sum_i32x4 = vec_splats((nk_i32_t)0);
}

/**
 * @brief Multiply-accumulates one pair of 16-element i8 slices into the biased state.
 */
NK_INTERNAL void nk_dot_i8x16_update_powervsx(nk_dot_i8x16_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // VMSUMMBM(b, a⊕0x80) = Σ(b_i · (a_i+128)) = a·b + 128·Σb
    // Swapping operands: b in signed slot, biased a in unsigned slot.
    // Correction −128·Σb uses precomputed B column sums from the compensated macro.
    nk_vu8x16_t const bias_u8x16 = vec_splats((nk_u8_t)0x80);
    nk_vu8x16_t a_biased_u8x16 = vec_xor(a.vu8x16, bias_u8x16); // x⊕0x80 == x+128 for i8 bits
    state->biased_sum_i32x4 = vec_msum(b.vi8x16, a_biased_u8x16, state->biased_sum_i32x4);
}

/**
 * @brief Reduces four biased accumulators and removes the +128 bias per element.
 *
 * @param a_sum  Unused here: the operand swap in update() makes the correction depend on
 *               Σb (per-column B sums in `b_sums`), not Σa.
 * @param b_sums Precomputed per-column sums of B, supplied by the packing stage.
 */
NK_INTERNAL void nk_dot_i8x16_finalize_powervsx( //
    nk_dot_i8x16_state_powervsx_t const *state_a, nk_dot_i8x16_state_powervsx_t const *state_b, //
    nk_dot_i8x16_state_powervsx_t const *state_c, nk_dot_i8x16_state_powervsx_t const *state_d, //
    nk_size_t total_dimensions, //
    nk_i32_t a_sum, nk_b128_vec_t b_sums, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);
    nk_unused_(a_sum);

    // Transpose-reduce biased products across 4 accumulators → one i32x4
    nk_vi32x4_t a_i32x4 = state_a->biased_sum_i32x4, b_i32x4 = state_b->biased_sum_i32x4,
                c_i32x4 = state_c->biased_sum_i32x4, d_i32x4 = state_d->biased_sum_i32x4;
    nk_vi32x4_t transpose_ab_low_i32x4 = vec_mergeh(a_i32x4, b_i32x4);
    nk_vi32x4_t transpose_cd_low_i32x4 = vec_mergeh(c_i32x4, d_i32x4);
    nk_vi32x4_t transpose_ab_high_i32x4 = vec_mergel(a_i32x4, b_i32x4);
    nk_vi32x4_t transpose_cd_high_i32x4 = vec_mergel(c_i32x4, d_i32x4);
    nk_vi32x4_t sum_lane0_i32x4 = (nk_vi32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_i32x4,
                                                            (nk_vu64x2_t)transpose_cd_low_i32x4, 0);
    nk_vi32x4_t sum_lane1_i32x4 = (nk_vi32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_i32x4,
                                                            (nk_vu64x2_t)transpose_cd_low_i32x4, 3);
    nk_vi32x4_t sum_lane2_i32x4 = (nk_vi32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_i32x4,
                                                            (nk_vu64x2_t)transpose_cd_high_i32x4, 0);
    nk_vi32x4_t sum_lane3_i32x4 = (nk_vi32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_i32x4,
                                                            (nk_vu64x2_t)transpose_cd_high_i32x4, 3);
    nk_vi32x4_t biased_i32x4 = vec_add(vec_add(sum_lane0_i32x4, sum_lane1_i32x4),
                                       vec_add(sum_lane2_i32x4, sum_lane3_i32x4));

    // Correction: VMSUMMBM(b, a⊕0x80) = Σ(b_i·(a_i+128)) = a·b + 128·Σb
    // So a·b = biased − 128·Σb. B column sums are precomputed during packing.
    // The ×128 is a logical left shift by 7 on the raw u32 bits (same low 32 bits
    // as a signed multiply, and wrap-around cancels in the i32 subtraction below).
    nk_vu32x4_t shift_u32x4 = vec_splats((nk_u32_t)7);
    nk_vi32x4_t correction_i32x4 = (nk_vi32x4_t)vec_sl((nk_vu32x4_t)b_sums.vi32x4, shift_u32x4);
    result->vi32x4 = vec_sub(biased_i32x4, correction_i32x4);
}
|
|
630
|
+
|
|
631
|
+
/** @brief Running state for i8 column sum precomputation on Power VSX. */
|
|
632
|
+
typedef struct nk_sum_i8x16_state_powervsx_t {
|
|
633
|
+
nk_vu32x4_t biased_sum_u32x4;
|
|
634
|
+
} nk_sum_i8x16_state_powervsx_t;
|
|
635
|
+
|
|
636
|
+
NK_INTERNAL void nk_sum_i8x16_init_powervsx(nk_sum_i8x16_state_powervsx_t *state) {
|
|
637
|
+
state->biased_sum_u32x4 = vec_splats((nk_u32_t)0);
|
|
638
|
+
}
|
|
639
|
+
|
|
640
|
+
NK_INTERNAL void nk_sum_i8x16_update_powervsx(nk_sum_i8x16_state_powervsx_t *state, nk_b128_vec_t values_vec) {
|
|
641
|
+
nk_vu8x16_t const bias_u8x16 = vec_splats((nk_u8_t)0x80);
|
|
642
|
+
nk_vu8x16_t biased_u8x16 = vec_xor(values_vec.vu8x16, bias_u8x16);
|
|
643
|
+
state->biased_sum_u32x4 = vec_sum4s(biased_u8x16, state->biased_sum_u32x4);
|
|
644
|
+
}
|
|
645
|
+
|
|
646
|
+
NK_INTERNAL nk_i32_t nk_sum_i8x16_finalize_powervsx(nk_sum_i8x16_state_powervsx_t const *state, nk_size_t count) {
|
|
647
|
+
nk_u32_t biased_sum = nk_hsum_u32x4_powervsx_(state->biased_sum_u32x4);
|
|
648
|
+
return (nk_i32_t)((nk_i64_t)biased_sum - 128 * (nk_i64_t)count);
|
|
649
|
+
}
|
|
650
|
+
|
|
651
|
+
/**
 * @brief Running state for 128-bit dot accumulation over u8 scalars on Power VSX.
 *
 * Processes 16 u8 values at a time via vec_msum, accumulating into 4 u32 lanes.
 */
typedef struct nk_dot_u8x16_state_powervsx_t {
    nk_vu32x4_t sum_u32x4; // 4 u32 partial sums; transpose-reduced at finalize
} nk_dot_u8x16_state_powervsx_t;

/** @brief Zeroes the running u32 partial sums. */
NK_INTERNAL void nk_dot_u8x16_init_powervsx(nk_dot_u8x16_state_powervsx_t *state) {
    state->sum_u32x4 = vec_splats((nk_u32_t)0);
}

/**
 * @brief Multiply-accumulates one pair of 16-element u8 slices into the state.
 */
NK_INTERNAL void nk_dot_u8x16_update_powervsx(nk_dot_u8x16_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // Unsigned × unsigned multiply-sum: 16 u8 products accumulated into 4 u32 lanes
    nk_vu8x16_t a_u8x16 = a.vu8x16;
    nk_vu8x16_t b_u8x16 = b.vu8x16;
    state->sum_u32x4 = vec_msum(a_u8x16, b_u8x16, state->sum_u32x4);
}

/**
 * @brief Reduces four independent 4-lane accumulators into four u32 results.
 *
 * 4×4 transpose via vec_mergeh/mergel + vec_xxpermdi, then the four transposed rows
 * are added so lane i of the result holds the total for state i.
 */
NK_INTERNAL void nk_dot_u8x16_finalize_powervsx( //
    nk_dot_u8x16_state_powervsx_t const *state_a, nk_dot_u8x16_state_powervsx_t const *state_b, //
    nk_dot_u8x16_state_powervsx_t const *state_c, nk_dot_u8x16_state_powervsx_t const *state_d, //
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);
    nk_vu32x4_t a_u32x4 = state_a->sum_u32x4, b_u32x4 = state_b->sum_u32x4, c_u32x4 = state_c->sum_u32x4,
                d_u32x4 = state_d->sum_u32x4;
    nk_vu32x4_t transpose_ab_low_u32x4 = vec_mergeh(a_u32x4, b_u32x4);
    nk_vu32x4_t transpose_cd_low_u32x4 = vec_mergeh(c_u32x4, d_u32x4);
    nk_vu32x4_t transpose_ab_high_u32x4 = vec_mergel(a_u32x4, b_u32x4);
    nk_vu32x4_t transpose_cd_high_u32x4 = vec_mergel(c_u32x4, d_u32x4);
    // Immediates 0 and 3 select the matching doubleword of each operand,
    // completing the transpose started by the merges above.
    nk_vu32x4_t sum_lane0_u32x4 = (nk_vu32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_u32x4,
                                                            (nk_vu64x2_t)transpose_cd_low_u32x4, 0);
    nk_vu32x4_t sum_lane1_u32x4 = (nk_vu32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_low_u32x4,
                                                            (nk_vu64x2_t)transpose_cd_low_u32x4, 3);
    nk_vu32x4_t sum_lane2_u32x4 = (nk_vu32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_u32x4,
                                                            (nk_vu64x2_t)transpose_cd_high_u32x4, 0);
    nk_vu32x4_t sum_lane3_u32x4 = (nk_vu32x4_t)vec_xxpermdi((nk_vu64x2_t)transpose_ab_high_u32x4,
                                                            (nk_vu64x2_t)transpose_cd_high_u32x4, 3);
    result->vu32x4 = vec_add(vec_add(sum_lane0_u32x4, sum_lane1_u32x4), vec_add(sum_lane2_u32x4, sum_lane3_u32x4));
}
|
|
695
|
+
|
|
696
|
+
/**
 * @brief Running state for 128-bit binary dot accumulation on Power VSX.
 *
 * Processes 128 bits (16 bytes) at a time via AND + doubleword popcount (vpopcntd),
 * accumulating bit-match counts into 2 u64 lanes.
 */
typedef struct nk_dot_u1x128_state_powervsx_t {
    nk_vu64x2_t dot_count_u64x2; // running count of positions where both inputs have a 1
} nk_dot_u1x128_state_powervsx_t;

/** @brief Zeroes the running match counter. */
NK_INTERNAL void nk_dot_u1x128_init_powervsx(nk_dot_u1x128_state_powervsx_t *state) {
    state->dot_count_u64x2 = vec_splats((nk_u64_t)0);
}

/**
 * @brief Accumulates one pair of 128-bit slices: popcount(a AND b) added to the state.
 */
NK_INTERNAL void nk_dot_u1x128_update_powervsx(nk_dot_u1x128_state_powervsx_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                               nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // AND → doubleword popcount (vpopcntd, 3cy ALU) → vec_add (7cy DP)
    // Simpler data flow than vpopcntb + vec_sum4s, and u64 accumulator holds larger counts
    nk_vu8x16_t a_u8x16 = a.vu8x16;
    nk_vu8x16_t b_u8x16 = b.vu8x16;
    nk_vu8x16_t and_u8x16 = vec_and(a_u8x16, b_u8x16);
    nk_vu64x2_t popcnt_u64x2 = vec_popcnt((nk_vu64x2_t)and_u8x16);
    state->dot_count_u64x2 = vec_add(state->dot_count_u64x2, popcnt_u64x2);
}

/**
 * @brief Reduces four match counters into four u32 results.
 *
 * Each state's two u64 lanes are combined by adding the doubleword-swapped copy,
 * then pairs of totals are gathered and vec_pack narrows u64 → u32.
 * NOTE(review): the pack truncates to 32 bits — safe as long as a single dot
 * product matches fewer than 2^32 bits; confirm dimension limits upstream.
 */
NK_INTERNAL void nk_dot_u1x128_finalize_powervsx( //
    nk_dot_u1x128_state_powervsx_t const *state_a, nk_dot_u1x128_state_powervsx_t const *state_b, //
    nk_dot_u1x128_state_powervsx_t const *state_c, nk_dot_u1x128_state_powervsx_t const *state_d, //
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);
    // xxpermdi(x, x, 2) swaps the two doublewords, so each lane of the sum holds the total.
    nk_vu64x2_t sum_a_u64x2 = vec_add(state_a->dot_count_u64x2,
                                      vec_xxpermdi(state_a->dot_count_u64x2, state_a->dot_count_u64x2, 2));
    nk_vu64x2_t sum_b_u64x2 = vec_add(state_b->dot_count_u64x2,
                                      vec_xxpermdi(state_b->dot_count_u64x2, state_b->dot_count_u64x2, 2));
    nk_vu64x2_t sum_c_u64x2 = vec_add(state_c->dot_count_u64x2,
                                      vec_xxpermdi(state_c->dot_count_u64x2, state_c->dot_count_u64x2, 2));
    nk_vu64x2_t sum_d_u64x2 = vec_add(state_d->dot_count_u64x2,
                                      vec_xxpermdi(state_d->dot_count_u64x2, state_d->dot_count_u64x2, 2));
    nk_vu64x2_t ab_u64x2 = vec_xxpermdi(sum_a_u64x2, sum_b_u64x2, 0);
    nk_vu64x2_t cd_u64x2 = vec_xxpermdi(sum_c_u64x2, sum_d_u64x2, 0);
    result->vu32x4 = vec_pack(ab_u64x2, cd_u64x2);
}
|
|
740
|
+
|
|
741
|
+
#if defined(__clang__)
|
|
742
|
+
#pragma clang attribute pop
|
|
743
|
+
#elif defined(__GNUC__)
|
|
744
|
+
#pragma GCC pop_options
|
|
745
|
+
#endif
|
|
746
|
+
|
|
747
|
+
#if defined(__cplusplus)
|
|
748
|
+
} // extern "C"
|
|
749
|
+
#endif
|
|
750
|
+
|
|
751
|
+
#endif // NK_TARGET_POWERVSX
|
|
752
|
+
#endif // NK_DOT_POWERVSX_H
|