numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -0,0 +1,738 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for Power VSX.
|
|
3
|
+
* @file include/numkong/spatial/powervsx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* @section spatial_powervsx_instructions Key Power VSX Spatial Instructions
|
|
10
|
+
*
|
|
11
|
+
* Power ISA 3.0 (POWER9+) VSX instructions for distance computations:
|
|
12
|
+
*
|
|
13
|
+
* Intrinsic Instruction POWER9
|
|
14
|
+
* vec_madd(f32) XVMADDASP 5cy
|
|
15
|
+
* vec_mul(f32) XVMULSP 5cy
|
|
16
|
+
* vec_add(f32) XVADDSP 5cy
|
|
17
|
+
* vec_sub(f32) XVSUBSP 5cy
|
|
18
|
+
* vec_rsqrte(f32) XVRSQRTESP 5cy
|
|
19
|
+
* vec_sqrt(f32) XVSQRTSP 26cy
|
|
20
|
+
* vec_doublee XVCVSPDP 3cy (f32 → f64 even elts)
|
|
21
|
+
* vec_xl_len LXVL 5cy (partial vector load)
|
|
22
|
+
* vec_extract_fp32_from_shorth XVCVHPSP 5cy (f16 → f32 high half)
|
|
23
|
+
* vec_extract_fp32_from_shortl XVCVHPSP 5cy (f16 → f32 low half)
|
|
24
|
+
* vec_msum(i8, u8, i32) VMSUMMBM 5cy (i8×u8 widening multiply-sum)
|
|
25
|
+
* vec_msum(u8, u8, u32) VMSUMUBM 5cy (u8×u8 widening multiply-sum)
|
|
26
|
+
* vec_unpackh(i8) VUPKHSB 2cy (sign-extend high 8 i8 → i16x8)
|
|
27
|
+
* vec_unpackl(i8) VUPKLSB 2cy (sign-extend low 8 i8 → i16x8)
|
|
28
|
+
*
|
|
29
|
+
* For angular distance, `vec_rsqrte` provides ~12-bit precision. Two Newton-Raphson
|
|
30
|
+
* iterations achieve ~23-bit precision for f32, three iterations for f64.
|
|
31
|
+
*/
|
|
32
|
+
#ifndef NK_SPATIAL_POWERVSX_H
|
|
33
|
+
#define NK_SPATIAL_POWERVSX_H
|
|
34
|
+
|
|
35
|
+
#if NK_TARGET_POWER_
|
|
36
|
+
#if NK_TARGET_POWERVSX
|
|
37
|
+
|
|
38
|
+
#include "numkong/types.h"
|
|
39
|
+
#include "numkong/dot/powervsx.h" // `nk_hsum_*_powervsx_`, includes cast/powervsx.h
|
|
40
|
+
#include "numkong/scalar/powervsx.h" // `nk_f32_sqrt_powervsx`, `nk_f64_sqrt_powervsx`
|
|
41
|
+
|
|
42
|
+
#if defined(__cplusplus)
|
|
43
|
+
extern "C" {
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
#if defined(__clang__)
|
|
47
|
+
#pragma clang attribute push(__attribute__((target("power9-vector"))), apply_to = function)
|
|
48
|
+
#elif defined(__GNUC__)
|
|
49
|
+
#pragma GCC push_options
|
|
50
|
+
#pragma GCC target("power9-vector")
|
|
51
|
+
#endif
|
|
52
|
+
|
|
53
|
+
/**
|
|
54
|
+
* @brief Reciprocal square root of 4 floats with Newton-Raphson refinement.
|
|
55
|
+
*
|
|
56
|
+
* Uses `vec_rsqrte` (~12-bit initial estimate) followed by two Newton-Raphson
|
|
57
|
+
* iterations, achieving ~23-bit precision sufficient for f32.
|
|
58
|
+
* NR step: rsqrt = rsqrt × (1.5 − 0.5 × x × rsqrt × rsqrt)
|
|
59
|
+
*/
|
|
60
|
+
NK_INTERNAL nk_vf32x4_t nk_rsqrt_f32x4_powervsx_(nk_vf32x4_t x) {
|
|
61
|
+
nk_vf32x4_t half_f32x4 = vec_splats(0.5f);
|
|
62
|
+
nk_vf32x4_t three_halves_f32x4 = vec_splats(1.5f);
|
|
63
|
+
nk_vf32x4_t rsqrt_f32x4 = vec_rsqrte(x);
|
|
64
|
+
// Iteration 1
|
|
65
|
+
nk_vf32x4_t nr_f32x4 = vec_sub(three_halves_f32x4,
|
|
66
|
+
vec_mul(half_f32x4, vec_mul(x, vec_mul(rsqrt_f32x4, rsqrt_f32x4))));
|
|
67
|
+
rsqrt_f32x4 = vec_mul(rsqrt_f32x4, nr_f32x4);
|
|
68
|
+
// Iteration 2
|
|
69
|
+
nr_f32x4 = vec_sub(three_halves_f32x4, vec_mul(half_f32x4, vec_mul(x, vec_mul(rsqrt_f32x4, rsqrt_f32x4))));
|
|
70
|
+
rsqrt_f32x4 = vec_mul(rsqrt_f32x4, nr_f32x4);
|
|
71
|
+
return rsqrt_f32x4;
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
/**
|
|
75
|
+
* @brief Reciprocal square root of 2 doubles with Newton-Raphson refinement.
|
|
76
|
+
*
|
|
77
|
+
* Uses `vec_rsqrte` (~12-bit estimate) followed by three Newton-Raphson
|
|
78
|
+
* iterations, achieving ~48-bit precision for f64.
|
|
79
|
+
*/
|
|
80
|
+
NK_INTERNAL nk_vf64x2_t nk_rsqrt_f64x2_powervsx_(nk_vf64x2_t x) {
|
|
81
|
+
nk_vf64x2_t half_f64x2 = vec_splats(0.5);
|
|
82
|
+
nk_vf64x2_t three_halves_f64x2 = vec_splats(1.5);
|
|
83
|
+
nk_vf64x2_t rsqrt_f64x2 = vec_rsqrte(x);
|
|
84
|
+
// Iteration 1
|
|
85
|
+
nk_vf64x2_t nr_f64x2 = vec_sub(three_halves_f64x2,
|
|
86
|
+
vec_mul(half_f64x2, vec_mul(x, vec_mul(rsqrt_f64x2, rsqrt_f64x2))));
|
|
87
|
+
rsqrt_f64x2 = vec_mul(rsqrt_f64x2, nr_f64x2);
|
|
88
|
+
// Iteration 2
|
|
89
|
+
nr_f64x2 = vec_sub(three_halves_f64x2, vec_mul(half_f64x2, vec_mul(x, vec_mul(rsqrt_f64x2, rsqrt_f64x2))));
|
|
90
|
+
rsqrt_f64x2 = vec_mul(rsqrt_f64x2, nr_f64x2);
|
|
91
|
+
// Iteration 3
|
|
92
|
+
nr_f64x2 = vec_sub(three_halves_f64x2, vec_mul(half_f64x2, vec_mul(x, vec_mul(rsqrt_f64x2, rsqrt_f64x2))));
|
|
93
|
+
rsqrt_f64x2 = vec_mul(rsqrt_f64x2, nr_f64x2);
|
|
94
|
+
return rsqrt_f64x2;
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
/**
 *  @brief Converts raw dot-products into an angular distance in [0, 1].
 *  @param ab Dot product of the two vectors.
 *  @param a2 Squared norm of the first vector.
 *  @param b2 Squared norm of the second vector.
 */
NK_INTERNAL nk_f32_t nk_angular_normalize_f32_powervsx_(nk_f32_t ab, nk_f32_t a2, nk_f32_t b2) {
    // Two all-zero vectors: distance is defined as zero
    if (a2 == 0 && b2 == 0) return 0;
    // Zero dot-product (covers the case where only one norm is zero): maximal distance
    if (ab == 0) return 1;
    // Compute 1/sqrt(a2) and 1/sqrt(b2) in one vectorized call
    nk_vf32x4_t packed = vec_splats(0.0f);
    packed = vec_insert(a2, packed, 0);
    packed = vec_insert(b2, packed, 1);
    nk_vf32x4_t inverted = nk_rsqrt_f32x4_powervsx_(packed);
    nk_f32_t distance = 1 - ab * vec_extract(inverted, 0) * vec_extract(inverted, 1);
    // Clamp tiny negative values produced by rounding
    return distance > 0 ? distance : 0;
}
|
|
109
|
+
|
|
110
|
+
/**
 *  @brief Converts raw dot-products into an angular distance in [0, 1], f64 variant.
 *  @param ab Dot product of the two vectors.
 *  @param a2 Squared norm of the first vector.
 *  @param b2 Squared norm of the second vector.
 */
NK_INTERNAL nk_f64_t nk_angular_normalize_f64_powervsx_(nk_f64_t ab, nk_f64_t a2, nk_f64_t b2) {
    // Two all-zero vectors: distance is defined as zero
    if (a2 == 0 && b2 == 0) return 0;
    // Zero dot-product (covers the case where only one norm is zero): maximal distance
    if (ab == 0) return 1;
    // Compute 1/sqrt(a2) and 1/sqrt(b2) in one vectorized call
    nk_vf64x2_t packed = vec_splats(0.0);
    packed = vec_insert(a2, packed, 0);
    packed = vec_insert(b2, packed, 1);
    nk_vf64x2_t inverted = nk_rsqrt_f64x2_powervsx_(packed);
    nk_f64_t distance = 1 - ab * vec_extract(inverted, 0) * vec_extract(inverted, 1);
    // Clamp tiny negative values produced by rounding
    return distance > 0 ? distance : 0;
}
|
|
122
|
+
|
|
123
|
+
#pragma region F32 and F64 Floats
|
|
124
|
+
|
|
125
|
+
/**
 *  @brief Squared Euclidean distance between two f32 vectors, accumulated in f64.
 *
 *  Inputs are widened to f64 with `vec_doublee` / `vec_doubleo` BEFORE the
 *  subtraction, avoiding f32 precision loss in (a - b). Tails shorter than
 *  4 elements use `vec_xl_len`, which zero-fills the unloaded bytes, so the
 *  padding lanes contribute nothing to the sum.
 */
NK_PUBLIC void nk_sqeuclidean_f32_powervsx(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_vf64x2_t even_acc_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t odd_acc_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf32x4_t a_f32x4, b_f32x4;

nk_sqeuclidean_f32_powervsx_cycle:
    if (n < 4) {
        // Partial load of the remaining `n` elements; the rest of the lanes are zero
        nk_size_t const tail_bytes = n * sizeof(nk_f32_t);
        a_f32x4 = vec_xl_len((nk_f32_t *)a, tail_bytes);
        b_f32x4 = vec_xl_len((nk_f32_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_f32x4 = vec_xl(0, a);
        b_f32x4 = vec_xl(0, b);
        a += 4, b += 4, n -= 4;
    }
    // Even lanes (0, 2) widened to f64
    nk_vf64x2_t diff_even_f64x2 = vec_sub(vec_doublee(a_f32x4), vec_doublee(b_f32x4));
    even_acc_f64x2 = vec_madd(diff_even_f64x2, diff_even_f64x2, even_acc_f64x2);
    // Odd lanes (1, 3) widened to f64
    nk_vf64x2_t diff_odd_f64x2 = vec_sub(vec_doubleo(a_f32x4), vec_doubleo(b_f32x4));
    odd_acc_f64x2 = vec_madd(diff_odd_f64x2, diff_odd_f64x2, odd_acc_f64x2);
    if (n) goto nk_sqeuclidean_f32_powervsx_cycle;

    *result = nk_hsum_f64x2_powervsx_(vec_add(even_acc_f64x2, odd_acc_f64x2));
}
|
|
158
|
+
|
|
159
|
+
/** @brief Euclidean distance between two f32 vectors: sqrt of the squared distance. */
NK_PUBLIC void nk_euclidean_f32_powervsx(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_sqeuclidean_f32_powervsx(a, b, n, result);
    *result = nk_f64_sqrt_powervsx(*result);
}
|
|
163
|
+
|
|
164
|
+
/**
 *  @brief Angular distance between two f32 vectors, accumulated in f64.
 *
 *  Maintains six f64 accumulators - even/odd lanes for a.b, a.a, and b.b -
 *  widening the f32 inputs with `vec_doublee` so the accumulation happens in
 *  double precision. Tails use zero-filling `vec_xl_len` loads, which leave
 *  all three sums unaffected.
 */
NK_PUBLIC void nk_angular_f32_powervsx(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_vf64x2_t dot_even_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t dot_odd_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t a_norm_even_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t a_norm_odd_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t b_norm_even_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t b_norm_odd_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf32x4_t a_f32x4, b_f32x4;

nk_angular_f32_powervsx_cycle:
    if (n < 4) {
        nk_size_t const tail_bytes = n * sizeof(nk_f32_t);
        a_f32x4 = vec_xl_len((nk_f32_t *)a, tail_bytes);
        b_f32x4 = vec_xl_len((nk_f32_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_f32x4 = vec_xl(0, a);
        b_f32x4 = vec_xl(0, b);
        a += 4, b += 4, n -= 4;
    }
    // Even lanes (0, 2) widened to f64
    nk_vf64x2_t a_even_f64x2 = vec_doublee(a_f32x4);
    nk_vf64x2_t b_even_f64x2 = vec_doublee(b_f32x4);
    dot_even_f64x2 = vec_madd(a_even_f64x2, b_even_f64x2, dot_even_f64x2);
    a_norm_even_f64x2 = vec_madd(a_even_f64x2, a_even_f64x2, a_norm_even_f64x2);
    b_norm_even_f64x2 = vec_madd(b_even_f64x2, b_even_f64x2, b_norm_even_f64x2);
    // Odd lanes (1, 3): rotate by one f32 (4 bytes) so `vec_doublee` picks them up
    nk_vf64x2_t a_odd_f64x2 = vec_doublee(vec_sld(a_f32x4, a_f32x4, 4));
    nk_vf64x2_t b_odd_f64x2 = vec_doublee(vec_sld(b_f32x4, b_f32x4, 4));
    dot_odd_f64x2 = vec_madd(a_odd_f64x2, b_odd_f64x2, dot_odd_f64x2);
    a_norm_odd_f64x2 = vec_madd(a_odd_f64x2, a_odd_f64x2, a_norm_odd_f64x2);
    b_norm_odd_f64x2 = vec_madd(b_odd_f64x2, b_odd_f64x2, b_norm_odd_f64x2);
    if (n) goto nk_angular_f32_powervsx_cycle;

    nk_f64_t ab = nk_hsum_f64x2_powervsx_(vec_add(dot_even_f64x2, dot_odd_f64x2));
    nk_f64_t a2 = nk_hsum_f64x2_powervsx_(vec_add(a_norm_even_f64x2, a_norm_odd_f64x2));
    nk_f64_t b2 = nk_hsum_f64x2_powervsx_(vec_add(b_norm_even_f64x2, b_norm_odd_f64x2));
    *result = nk_angular_normalize_f64_powervsx_(ab, a2, b2);
}
|
|
208
|
+
|
|
209
|
+
/**
 *  @brief Squared Euclidean distance between two f64 vectors.
 *
 *  Accumulates (a - b)^2 across 2-wide f64 lanes with FMA, reducing at the
 *  end. Tails shorter than 2 elements use zero-filling `vec_xl_len` loads.
 */
NK_PUBLIC void nk_sqeuclidean_f64_powervsx(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_vf64x2_t sum_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t a_f64x2, b_f64x2;

nk_sqeuclidean_f64_powervsx_cycle:
    if (n < 2) {
        nk_size_t const tail_bytes = n * sizeof(nk_f64_t);
        a_f64x2 = vec_xl_len((nk_f64_t *)a, tail_bytes);
        b_f64x2 = vec_xl_len((nk_f64_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_f64x2 = vec_xl(0, a);
        b_f64x2 = vec_xl(0, b);
        a += 2, b += 2, n -= 2;
    }
    nk_vf64x2_t diff_f64x2 = vec_sub(a_f64x2, b_f64x2);
    sum_f64x2 = vec_madd(diff_f64x2, diff_f64x2, sum_f64x2);
    if (n) goto nk_sqeuclidean_f64_powervsx_cycle;

    *result = nk_hsum_f64x2_powervsx_(sum_f64x2);
}
|
|
232
|
+
|
|
233
|
+
/** @brief Euclidean distance between two f64 vectors: sqrt of the squared distance. */
NK_PUBLIC void nk_euclidean_f64_powervsx(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_sqeuclidean_f64_powervsx(a, b, n, result);
    *result = nk_f64_sqrt_powervsx(*result);
}
|
|
237
|
+
|
|
238
|
+
/**
 *  @brief Angular distance between two f64 vectors with compensated accumulation.
 *
 *  The cross-product a.b may suffer catastrophic cancellation, so it is
 *  accumulated with the Dot2 scheme (Ogita-Rump-Oishi): TwoProd via FMA to
 *  recover the multiplication error, TwoSum to recover the addition error.
 *  The self-products a.a and b.b are strictly non-negative (no cancellation)
 *  and use plain FMA accumulators.
 */
NK_PUBLIC void nk_angular_f64_powervsx(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_vf64x2_t dot_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t compensation_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t a_norm_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t b_norm_f64x2 = vec_splats((nk_f64_t)0);
    nk_vf64x2_t a_f64x2, b_f64x2;

nk_angular_f64_powervsx_cycle:
    if (n < 2) {
        nk_size_t const tail_bytes = n * sizeof(nk_f64_t);
        a_f64x2 = vec_xl_len((nk_f64_t *)a, tail_bytes);
        b_f64x2 = vec_xl_len((nk_f64_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_f64x2 = vec_xl(0, a);
        b_f64x2 = vec_xl(0, b);
        a += 2, b += 2, n -= 2;
    }
    // TwoProd: `product` is the rounded a*b; FMA-based msub recovers the rounding error
    nk_vf64x2_t product_f64x2 = vec_mul(a_f64x2, b_f64x2);
    nk_vf64x2_t product_error_f64x2 = vec_msub(a_f64x2, b_f64x2, product_f64x2);
    // TwoSum: fold `product` into the running sum and capture the addition error
    nk_vf64x2_t new_sum_f64x2 = vec_add(dot_f64x2, product_f64x2);
    nk_vf64x2_t shifted_f64x2 = vec_sub(new_sum_f64x2, dot_f64x2);
    nk_vf64x2_t addition_error_f64x2 = vec_add(vec_sub(dot_f64x2, vec_sub(new_sum_f64x2, shifted_f64x2)),
                                               vec_sub(product_f64x2, shifted_f64x2));
    dot_f64x2 = new_sum_f64x2;
    compensation_f64x2 = vec_add(compensation_f64x2, vec_add(addition_error_f64x2, product_error_f64x2));
    // Plain FMA for the non-cancelling self-products
    a_norm_f64x2 = vec_madd(a_f64x2, a_f64x2, a_norm_f64x2);
    b_norm_f64x2 = vec_madd(b_f64x2, b_f64x2, b_norm_f64x2);
    if (n) goto nk_angular_f64_powervsx_cycle;

    *result = nk_angular_normalize_f64_powervsx_(nk_dot_stable_sum_f64x2_powervsx_(dot_f64x2, compensation_f64x2),
                                                 nk_hsum_f64x2_powervsx_(a_norm_f64x2),
                                                 nk_hsum_f64x2_powervsx_(b_norm_f64x2));
}
|
|
278
|
+
|
|
279
|
+
#pragma endregion F32 and F64 Floats
|
|
280
|
+
#pragma region F16 and BF16 Floats
|
|
281
|
+
|
|
282
|
+
/**
 *  @brief Squared Euclidean distance between two bf16 vectors, accumulated in f32.
 *
 *  bf16 -> f32 conversion is a merge with a zero vector: each bf16 payload
 *  lands in one half of a 32-bit lane with the other half zeroed, which
 *  reproduces the f32 bit pattern (bf16 is the top 16 bits of an f32).
 */
NK_PUBLIC void nk_sqeuclidean_bf16_powervsx(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_vu16x8_t const zero_u16x8 = vec_splats((nk_u16_t)0);
    nk_vf32x4_t sum_f32x4 = vec_splats(0.0f);
    nk_vu16x8_t a_u16x8, b_u16x8;

nk_sqeuclidean_bf16_powervsx_cycle:
    if (n < 8) {
        // Partial load; unloaded lanes are zero and contribute nothing
        nk_size_t const tail_bytes = n * sizeof(nk_bf16_t);
        a_u16x8 = vec_xl_len((nk_u16_t *)a, tail_bytes);
        b_u16x8 = vec_xl_len((nk_u16_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_u16x8 = vec_xl(0, (nk_u16_t const *)a);
        b_u16x8 = vec_xl(0, (nk_u16_t const *)b);
        a += 8, b += 8, n -= 8;
    }
    nk_vf32x4_t a_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, a_u16x8);
    nk_vf32x4_t b_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, b_u16x8);
    nk_vf32x4_t a_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, a_u16x8);
    nk_vf32x4_t b_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, b_u16x8);
    nk_vf32x4_t diff_high_f32x4 = vec_sub(a_high_f32x4, b_high_f32x4);
    nk_vf32x4_t diff_low_f32x4 = vec_sub(a_low_f32x4, b_low_f32x4);
    sum_f32x4 = vec_madd(diff_high_f32x4, diff_high_f32x4, sum_f32x4);
    sum_f32x4 = vec_madd(diff_low_f32x4, diff_low_f32x4, sum_f32x4);
    if (n) goto nk_sqeuclidean_bf16_powervsx_cycle;
    *result = nk_hsum_f32x4_powervsx_(sum_f32x4);
}
|
|
312
|
+
|
|
313
|
+
/** @brief Euclidean distance between two bf16 vectors: sqrt of the squared distance. */
NK_PUBLIC void nk_euclidean_bf16_powervsx(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_sqeuclidean_bf16_powervsx(a, b, n, result);
    *result = nk_f32_sqrt_powervsx(*result);
}
|
|
317
|
+
|
|
318
|
+
/**
 *  @brief Angular distance between two bf16 vectors, accumulated in f32.
 *
 *  bf16 -> f32 conversion via merge with zero (see the squared-Euclidean
 *  variant); accumulates a.b, a.a, and b.b with FMA, then normalizes.
 */
NK_PUBLIC void nk_angular_bf16_powervsx(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_vu16x8_t const zero_u16x8 = vec_splats((nk_u16_t)0);
    nk_vf32x4_t dot_f32x4 = vec_splats(0.0f);
    nk_vf32x4_t a_norm_f32x4 = vec_splats(0.0f);
    nk_vf32x4_t b_norm_f32x4 = vec_splats(0.0f);
    nk_vu16x8_t a_u16x8, b_u16x8;

nk_angular_bf16_powervsx_cycle:
    if (n < 8) {
        // Partial load; zeroed lanes contribute nothing to any of the three sums
        nk_size_t const tail_bytes = n * sizeof(nk_bf16_t);
        a_u16x8 = vec_xl_len((nk_u16_t *)a, tail_bytes);
        b_u16x8 = vec_xl_len((nk_u16_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_u16x8 = vec_xl(0, (nk_u16_t const *)a);
        b_u16x8 = vec_xl(0, (nk_u16_t const *)b);
        a += 8, b += 8, n -= 8;
    }
    nk_vf32x4_t a_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, a_u16x8);
    nk_vf32x4_t a_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, a_u16x8);
    nk_vf32x4_t b_high_f32x4 = (nk_vf32x4_t)vec_mergeh(zero_u16x8, b_u16x8);
    nk_vf32x4_t b_low_f32x4 = (nk_vf32x4_t)vec_mergel(zero_u16x8, b_u16x8);
    dot_f32x4 = vec_madd(a_high_f32x4, b_high_f32x4, dot_f32x4);
    dot_f32x4 = vec_madd(a_low_f32x4, b_low_f32x4, dot_f32x4);
    a_norm_f32x4 = vec_madd(a_high_f32x4, a_high_f32x4, a_norm_f32x4);
    a_norm_f32x4 = vec_madd(a_low_f32x4, a_low_f32x4, a_norm_f32x4);
    b_norm_f32x4 = vec_madd(b_high_f32x4, b_high_f32x4, b_norm_f32x4);
    b_norm_f32x4 = vec_madd(b_low_f32x4, b_low_f32x4, b_norm_f32x4);
    if (n) goto nk_angular_bf16_powervsx_cycle;
    *result = nk_angular_normalize_f32_powervsx_(nk_hsum_f32x4_powervsx_(dot_f32x4),
                                                 nk_hsum_f32x4_powervsx_(a_norm_f32x4),
                                                 nk_hsum_f32x4_powervsx_(b_norm_f32x4));
}
|
|
354
|
+
|
|
355
|
+
/**
 *  @brief Squared Euclidean distance between two f16 vectors, accumulated in f32.
 *
 *  f16 -> f32 widening uses the POWER9 XVCVHPSP instruction through the
 *  `vec_extract_fp32_from_shorth` / `vec_extract_fp32_from_shortl` intrinsics.
 */
NK_PUBLIC void nk_sqeuclidean_f16_powervsx(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_vf32x4_t sum_f32x4 = vec_splats(0.0f);
    nk_vu16x8_t a_u16x8, b_u16x8;

nk_sqeuclidean_f16_powervsx_cycle:
    if (n < 8) {
        // Partial load; unloaded lanes are zero and contribute nothing
        nk_size_t const tail_bytes = n * sizeof(nk_f16_t);
        a_u16x8 = vec_xl_len((nk_u16_t *)a, tail_bytes);
        b_u16x8 = vec_xl_len((nk_u16_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_u16x8 = vec_xl(0, (nk_u16_t const *)a);
        b_u16x8 = vec_xl(0, (nk_u16_t const *)b);
        a += 8, b += 8, n -= 8;
    }
    nk_vf32x4_t a_high_f32x4 = vec_extract_fp32_from_shorth(a_u16x8);
    nk_vf32x4_t b_high_f32x4 = vec_extract_fp32_from_shorth(b_u16x8);
    nk_vf32x4_t a_low_f32x4 = vec_extract_fp32_from_shortl(a_u16x8);
    nk_vf32x4_t b_low_f32x4 = vec_extract_fp32_from_shortl(b_u16x8);
    nk_vf32x4_t diff_high_f32x4 = vec_sub(a_high_f32x4, b_high_f32x4);
    nk_vf32x4_t diff_low_f32x4 = vec_sub(a_low_f32x4, b_low_f32x4);
    sum_f32x4 = vec_madd(diff_high_f32x4, diff_high_f32x4, sum_f32x4);
    sum_f32x4 = vec_madd(diff_low_f32x4, diff_low_f32x4, sum_f32x4);
    if (n) goto nk_sqeuclidean_f16_powervsx_cycle;
    *result = nk_hsum_f32x4_powervsx_(sum_f32x4);
}
|
|
384
|
+
|
|
385
|
+
/** @brief Euclidean distance for f16: square root of the squared-distance kernel. */
NK_PUBLIC void nk_euclidean_f16_powervsx(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_f16_powervsx(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_powervsx(squared_distance);
}
|
|
389
|
+
|
|
390
|
+
NK_PUBLIC void nk_angular_f16_powervsx(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance over f16 vectors: accumulate a·b, a·a, b·b in f32 lanes.
    // f16 → f32 via POWER9 hardware XVCVHPSP
    nk_vf32x4_t ab_f32x4 = vec_splats(0.0f);
    nk_vf32x4_t a2_f32x4 = vec_splats(0.0f);
    nk_vf32x4_t b2_f32x4 = vec_splats(0.0f);
    nk_vu16x8_t a_u16x8, b_u16x8;
    nk_size_t tail_bytes;

nk_angular_f16_powervsx_cycle:
    if (n < 8) {
        // Tail: length-limited load zero-fills; +0.0 lanes add nothing to any accumulator.
        tail_bytes = n * sizeof(nk_f16_t);
        a_u16x8 = vec_xl_len((nk_u16_t *)a, tail_bytes);
        b_u16x8 = vec_xl_len((nk_u16_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_u16x8 = vec_xl(0, (nk_u16_t const *)a);
        b_u16x8 = vec_xl(0, (nk_u16_t const *)b);
        a += 8, b += 8, n -= 8;
    }
    // Convert 8 half-precision values into two f32x4 halves.
    nk_vf32x4_t a_high_f32x4 = vec_extract_fp32_from_shorth(a_u16x8);
    nk_vf32x4_t a_low_f32x4 = vec_extract_fp32_from_shortl(a_u16x8);
    nk_vf32x4_t b_high_f32x4 = vec_extract_fp32_from_shorth(b_u16x8);
    nk_vf32x4_t b_low_f32x4 = vec_extract_fp32_from_shortl(b_u16x8);
    // Three independent FMA chains for dot product and both squared norms.
    ab_f32x4 = vec_madd(a_high_f32x4, b_high_f32x4, ab_f32x4);
    ab_f32x4 = vec_madd(a_low_f32x4, b_low_f32x4, ab_f32x4);
    a2_f32x4 = vec_madd(a_high_f32x4, a_high_f32x4, a2_f32x4);
    a2_f32x4 = vec_madd(a_low_f32x4, a_low_f32x4, a2_f32x4);
    b2_f32x4 = vec_madd(b_high_f32x4, b_high_f32x4, b2_f32x4);
    b2_f32x4 = vec_madd(b_low_f32x4, b_low_f32x4, b2_f32x4);
    if (n) goto nk_angular_f16_powervsx_cycle;
    nk_f32_t ab = nk_hsum_f32x4_powervsx_(ab_f32x4);
    nk_f32_t a2 = nk_hsum_f32x4_powervsx_(a2_f32x4);
    nk_f32_t b2 = nk_hsum_f32x4_powervsx_(b2_f32x4);
    *result = nk_angular_normalize_f32_powervsx_(ab, a2, b2);
}
|
|
426
|
+
|
|
427
|
+
#pragma endregion F16 and BF16 Floats
|
|
428
|
+
#pragma region I8 and U8 Integers
|
|
429
|
+
|
|
430
|
+
NK_PUBLIC void nk_sqeuclidean_i8_powervsx(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
    // Squared Euclidean distance over i8 vectors, accumulated exactly in i32.
    // Power has no vabdq_s8. Widen i8 → i16 via vec_unpackh/vec_unpackl,
    // subtract in i16, then vec_msum(diff_i16, diff_i16, accumulator_i32) to square-accumulate.
    nk_vi32x4_t accumulator_i32x4 = vec_splats((nk_i32_t)0);
    nk_vi8x16_t a_i8x16, b_i8x16;
    nk_size_t tail_bytes;

nk_sqeuclidean_i8_powervsx_cycle:
    if (n < 16) {
        // Tail: length-limited load zero-fills, so padded lanes have diff == 0.
        tail_bytes = n * sizeof(nk_i8_t);
        a_i8x16 = vec_xl_len((nk_i8_t *)a, tail_bytes);
        b_i8x16 = vec_xl_len((nk_i8_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_i8x16 = vec_xl(0, a);
        b_i8x16 = vec_xl(0, b);
        a += 16, b += 16, n -= 16;
    }
    // Widen high 8 bytes: i8 → i16
    nk_vi16x8_t a_high_i16x8 = vec_unpackh(a_i8x16);
    nk_vi16x8_t b_high_i16x8 = vec_unpackh(b_i8x16);
    nk_vi16x8_t diff_high_i16x8 = vec_sub(a_high_i16x8, b_high_i16x8);
    // vec_msum: multiply 8 i16 pairs and accumulate into 4 i32 lanes
    accumulator_i32x4 = vec_msum(diff_high_i16x8, diff_high_i16x8, accumulator_i32x4);
    // Widen low 8 bytes: i8 → i16
    nk_vi16x8_t a_low_i16x8 = vec_unpackl(a_i8x16);
    nk_vi16x8_t b_low_i16x8 = vec_unpackl(b_i8x16);
    nk_vi16x8_t diff_low_i16x8 = vec_sub(a_low_i16x8, b_low_i16x8);
    accumulator_i32x4 = vec_msum(diff_low_i16x8, diff_low_i16x8, accumulator_i32x4);
    if (n) goto nk_sqeuclidean_i8_powervsx_cycle;

    // Squares are non-negative, so reinterpreting the signed sum as u32 is safe.
    *result = (nk_u32_t)nk_hsum_i32x4_powervsx_(accumulator_i32x4);
}
|
|
464
|
+
|
|
465
|
+
/** @brief Euclidean distance for i8: square root of the exact integer squared distance. */
NK_PUBLIC void nk_euclidean_i8_powervsx(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_u32_t squared_distance;
    nk_sqeuclidean_i8_powervsx(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_powervsx((nk_f32_t)squared_distance);
}
|
|
470
|
+
|
|
471
|
+
NK_PUBLIC void nk_angular_i8_powervsx(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    // Hybrid approach for 3-accumulator i8 angular distance:
    // a·b: algebraic transform — VMSUMMBM(a, b⊕0x80) with correction −128·Σa
    // a·a: abs-based unsigned — VMSUMUBM(|a|, |a|), no correction needed
    // b·b: abs-based unsigned — VMSUMUBM(|b|, |b|), no correction needed
    // abs(-128)→-128 in i8 → 128 as u8 → 128²=16384=(-128)². Safe for all values.
    // 3 independent MSUM chains → excellent ILP on POWER9's dual-issue p01.
    nk_vu8x16_t const bias_u8x16 = vec_splats((nk_u8_t)0x80);
    nk_vi8x16_t const zeros_i8x16 = vec_splats((nk_i8_t)0);
    nk_vi32x4_t dot_product_i32x4 = vec_splats((nk_i32_t)0);
    nk_vu32x4_t a_norm_sq_u32x4 = vec_splats((nk_u32_t)0);
    nk_vu32x4_t b_norm_sq_u32x4 = vec_splats((nk_u32_t)0);
    nk_vu32x4_t sum_a_biased_u32x4 = vec_splats((nk_u32_t)0);
    // Total lanes processed including tail zero-padding; needed to undo the
    // +128 bias that every processed lane (padded or not) contributes to Σ(a+128).
    nk_size_t count_padded = ((n + 15) / 16) * 16;
    nk_vi8x16_t a_i8x16, b_i8x16;
    nk_size_t tail_bytes;

nk_angular_i8_powervsx_cycle:
    if (n < 16) {
        // Tail: zero-filled lanes contribute 0 to the biased dot (a = 0 factor)
        // and 0 to both norms; only Σ(a+128) sees them, which count_padded corrects.
        tail_bytes = n * sizeof(nk_i8_t);
        a_i8x16 = vec_xl_len((nk_i8_t *)a, tail_bytes);
        b_i8x16 = vec_xl_len((nk_i8_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_i8x16 = vec_xl(0, a);
        b_i8x16 = vec_xl(0, b);
        a += 16, b += 16, n -= 16;
    }

    // Dot product: algebraic via VMSUMMBM(i8 × u8 → i32)
    nk_vu8x16_t b_biased_u8x16 = vec_xor((nk_vu8x16_t)b_i8x16, bias_u8x16);
    dot_product_i32x4 = vec_msum(a_i8x16, b_biased_u8x16, dot_product_i32x4);
    // Correction sum: Σ(a+128) via VSUM4UBS (XOR with 0x80 maps i8 → u8 a+128)
    sum_a_biased_u32x4 = vec_sum4s(vec_xor((nk_vu8x16_t)a_i8x16, bias_u8x16), sum_a_biased_u32x4);
    // Norms: |a|² and |b|² via VMSUMUBM(u8 × u8 → u32) on absolute values
    nk_vu8x16_t a_abs_u8x16 = (nk_vu8x16_t)vec_max(a_i8x16, vec_sub(zeros_i8x16, a_i8x16));
    nk_vu8x16_t b_abs_u8x16 = (nk_vu8x16_t)vec_max(b_i8x16, vec_sub(zeros_i8x16, b_i8x16));
    a_norm_sq_u32x4 = vec_msum(a_abs_u8x16, a_abs_u8x16, a_norm_sq_u32x4);
    b_norm_sq_u32x4 = vec_msum(b_abs_u8x16, b_abs_u8x16, b_norm_sq_u32x4);

    if (n) goto nk_angular_i8_powervsx_cycle;

    // Correct the biased dot product: a·b = biased − 128·Σa = biased − 128·(Σ(a+128) − 128·count_padded)
    nk_i64_t correction = 128LL * (nk_i64_t)nk_hsum_u32x4_powervsx_(sum_a_biased_u32x4) -
                          16384LL * (nk_i64_t)count_padded;
    nk_i32_t dot_product_i32 = (nk_i32_t)((nk_i64_t)nk_hsum_i32x4_powervsx_(dot_product_i32x4) - correction);
    nk_u32_t a_norm_sq_u32 = nk_hsum_u32x4_powervsx_(a_norm_sq_u32x4);
    nk_u32_t b_norm_sq_u32 = nk_hsum_u32x4_powervsx_(b_norm_sq_u32x4);
    *result = nk_angular_normalize_f32_powervsx_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_u32,
                                                 (nk_f32_t)b_norm_sq_u32);
}
|
|
523
|
+
|
|
524
|
+
NK_PUBLIC void nk_sqeuclidean_u8_powervsx(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
    // Squared Euclidean distance over u8 vectors, accumulated exactly in u32.
    // Compute |a-b| without underflow: vec_sub(vec_max(a, b), vec_min(a, b))
    // Then square-accumulate via vec_msum(u8, u8, u32) → VMSUMUBM
    nk_vu32x4_t accumulator_u32x4 = vec_splats((nk_u32_t)0);
    nk_vu8x16_t a_u8x16, b_u8x16;
    nk_size_t tail_bytes;

nk_sqeuclidean_u8_powervsx_cycle:
    if (n < 16) {
        // Tail: length-limited load zero-fills, so padded lanes have |a−b| == 0.
        tail_bytes = n * sizeof(nk_u8_t);
        a_u8x16 = vec_xl_len((nk_u8_t *)a, tail_bytes);
        b_u8x16 = vec_xl_len((nk_u8_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_u8x16 = vec_xl(0, a);
        b_u8x16 = vec_xl(0, b);
        a += 16, b += 16, n -= 16;
    }
    nk_vu8x16_t diff_u8x16 = vec_sub(vec_max(a_u8x16, b_u8x16), vec_min(a_u8x16, b_u8x16));
    // VMSUMUBM: u8 × u8 → u32 accumulate
    accumulator_u32x4 = vec_msum(diff_u8x16, diff_u8x16, accumulator_u32x4);
    if (n) goto nk_sqeuclidean_u8_powervsx_cycle;

    *result = nk_hsum_u32x4_powervsx_(accumulator_u32x4);
}
|
|
550
|
+
|
|
551
|
+
/** @brief Euclidean distance for u8: square root of the exact integer squared distance. */
NK_PUBLIC void nk_euclidean_u8_powervsx(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_u32_t squared_distance;
    nk_sqeuclidean_u8_powervsx(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_powervsx((nk_f32_t)squared_distance);
}
|
|
556
|
+
|
|
557
|
+
NK_PUBLIC void nk_angular_u8_powervsx(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance over u8 vectors.
    // Triple accumulator in u32 using vec_msum(u8, u8, u32) → VMSUMUBM
    nk_vu32x4_t ab_u32x4 = vec_splats((nk_u32_t)0);
    nk_vu32x4_t aa_u32x4 = vec_splats((nk_u32_t)0);
    nk_vu32x4_t bb_u32x4 = vec_splats((nk_u32_t)0);
    nk_vu8x16_t a_u8x16, b_u8x16;
    nk_size_t tail_bytes;

nk_angular_u8_powervsx_cycle:
    if (n < 16) {
        // Tail: length-limited load zero-fills; zero lanes add nothing to any accumulator.
        tail_bytes = n * sizeof(nk_u8_t);
        a_u8x16 = vec_xl_len((nk_u8_t *)a, tail_bytes);
        b_u8x16 = vec_xl_len((nk_u8_t *)b, tail_bytes);
        n = 0;
    }
    else {
        a_u8x16 = vec_xl(0, a);
        b_u8x16 = vec_xl(0, b);
        a += 16, b += 16, n -= 16;
    }
    // VMSUMUBM: u8 × u8 → u32 accumulate; three independent chains for ILP.
    ab_u32x4 = vec_msum(a_u8x16, b_u8x16, ab_u32x4);
    aa_u32x4 = vec_msum(a_u8x16, a_u8x16, aa_u32x4);
    bb_u32x4 = vec_msum(b_u8x16, b_u8x16, bb_u32x4);
    if (n) goto nk_angular_u8_powervsx_cycle;

    nk_u32_t ab = nk_hsum_u32x4_powervsx_(ab_u32x4);
    nk_u32_t aa = nk_hsum_u32x4_powervsx_(aa_u32x4);
    nk_u32_t bb = nk_hsum_u32x4_powervsx_(bb_u32x4);
    *result = nk_angular_normalize_f32_powervsx_((nk_f32_t)ab, (nk_f32_t)aa, (nk_f32_t)bb);
}
|
|
588
|
+
|
|
589
|
+
/** @brief Angular from_dot: computes 1 − dot × rsqrt(query_sumsq × target_sumsq) for 4 pairs in f64. */
NK_INTERNAL void nk_angular_through_f64_from_dot_powervsx_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
                                                           nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
    nk_vf64x2_t dots_ab_f64x2 = dots.vf64x2s[0];
    nk_vf64x2_t dots_cd_f64x2 = dots.vf64x2s[1];
    nk_vf64x2_t query_f64x2 = vec_splats(query_sumsq);
    nk_vf64x2_t targets_ab_f64x2 = target_sumsqs.vf64x2s[0];
    nk_vf64x2_t targets_cd_f64x2 = target_sumsqs.vf64x2s[1];

    nk_vf64x2_t products_ab_f64x2 = vec_mul(query_f64x2, targets_ab_f64x2);
    nk_vf64x2_t products_cd_f64x2 = vec_mul(query_f64x2, targets_cd_f64x2);

    nk_vf64x2_t rsqrt_ab_f64x2 = nk_rsqrt_f64x2_powervsx_(products_ab_f64x2);
    nk_vf64x2_t rsqrt_cd_f64x2 = nk_rsqrt_f64x2_powervsx_(products_cd_f64x2);

    // 1 − dot/√(‖q‖²·‖t‖²), clamped below at 0 against rounding overshoot.
    nk_vf64x2_t ones_f64x2 = vec_splats(1.0);
    nk_vf64x2_t zeros_f64x2 = vec_splats(0.0);
    nk_vf64x2_t result_ab_f64x2 = vec_max(vec_sub(ones_f64x2, vec_mul(dots_ab_f64x2, rsqrt_ab_f64x2)), zeros_f64x2);
    nk_vf64x2_t result_cd_f64x2 = vec_max(vec_sub(ones_f64x2, vec_mul(dots_cd_f64x2, rsqrt_cd_f64x2)), zeros_f64x2);

    // Degenerate-norm handling per lane: if the norm product is zero AND the dot
    // is zero (both vectors empty) → distance 0; if the norm product is zero but
    // the dot is not → distance 1.
    nk_vu64x2_t prodzero_ab_u64x2 = (nk_vu64x2_t)vec_cmpeq(products_ab_f64x2, zeros_f64x2);
    nk_vu64x2_t dotzero_ab_u64x2 = (nk_vu64x2_t)vec_cmpeq(dots_ab_f64x2, zeros_f64x2);
    result_ab_f64x2 = vec_sel(result_ab_f64x2, zeros_f64x2, vec_and(prodzero_ab_u64x2, dotzero_ab_u64x2));
    result_ab_f64x2 = vec_sel(result_ab_f64x2, ones_f64x2, vec_andc(prodzero_ab_u64x2, dotzero_ab_u64x2));

    nk_vu64x2_t prodzero_cd_u64x2 = (nk_vu64x2_t)vec_cmpeq(products_cd_f64x2, zeros_f64x2);
    nk_vu64x2_t dotzero_cd_u64x2 = (nk_vu64x2_t)vec_cmpeq(dots_cd_f64x2, zeros_f64x2);
    result_cd_f64x2 = vec_sel(result_cd_f64x2, zeros_f64x2, vec_and(prodzero_cd_u64x2, dotzero_cd_u64x2));
    result_cd_f64x2 = vec_sel(result_cd_f64x2, ones_f64x2, vec_andc(prodzero_cd_u64x2, dotzero_cd_u64x2));

    results->vf64x2s[0] = result_ab_f64x2;
    results->vf64x2s[1] = result_cd_f64x2;
}
|
|
622
|
+
|
|
623
|
+
/** @brief Euclidean from_dot: computes √(query_sumsq + target_sumsq − 2×dot) for 4 pairs in f64. */
|
|
624
|
+
NK_INTERNAL void nk_euclidean_through_f64_from_dot_powervsx_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
|
|
625
|
+
nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
|
|
626
|
+
nk_vf64x2_t query_f64x2 = vec_splats(query_sumsq);
|
|
627
|
+
nk_vf64x2_t neg_two_f64x2 = vec_splats(-2.0);
|
|
628
|
+
nk_vf64x2_t zeros_f64x2 = vec_splats(0.0);
|
|
629
|
+
|
|
630
|
+
nk_vf64x2_t sum_sq_ab_f64x2 = vec_add(query_f64x2, target_sumsqs.vf64x2s[0]);
|
|
631
|
+
nk_vf64x2_t sum_sq_cd_f64x2 = vec_add(query_f64x2, target_sumsqs.vf64x2s[1]);
|
|
632
|
+
nk_vf64x2_t dist_sq_ab_f64x2 = vec_max(vec_madd(neg_two_f64x2, dots.vf64x2s[0], sum_sq_ab_f64x2), zeros_f64x2);
|
|
633
|
+
nk_vf64x2_t dist_sq_cd_f64x2 = vec_max(vec_madd(neg_two_f64x2, dots.vf64x2s[1], sum_sq_cd_f64x2), zeros_f64x2);
|
|
634
|
+
|
|
635
|
+
results->vf64x2s[0] = vec_sqrt(dist_sq_ab_f64x2);
|
|
636
|
+
results->vf64x2s[1] = vec_sqrt(dist_sq_cd_f64x2);
|
|
637
|
+
}
|
|
638
|
+
|
|
639
|
+
/** @brief Angular from_dot: computes 1 − dot × rsqrt(query_sumsq × target_sumsq) for 4 pairs in f32. */
|
|
640
|
+
NK_INTERNAL void nk_angular_through_f32_from_dot_powervsx_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
641
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
642
|
+
nk_vf32x4_t dots_f32x4 = dots.vf32x4;
|
|
643
|
+
nk_vf32x4_t query_f32x4 = vec_splats(query_sumsq);
|
|
644
|
+
nk_vf32x4_t products_f32x4 = vec_mul(query_f32x4, target_sumsqs.vf32x4);
|
|
645
|
+
nk_vf32x4_t rsqrt_f32x4 = nk_rsqrt_f32x4_powervsx_(products_f32x4);
|
|
646
|
+
nk_vf32x4_t normalized_f32x4 = vec_mul(dots_f32x4, rsqrt_f32x4);
|
|
647
|
+
nk_vf32x4_t angular_f32x4 = vec_sub(vec_splats(1.0f), normalized_f32x4);
|
|
648
|
+
nk_vf32x4_t result_f32x4 = vec_max(angular_f32x4, vec_splats(0.0f));
|
|
649
|
+
results->vf32x4 = result_f32x4;
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
/** @brief Euclidean from_dot: computes √(query_sumsq + target_sumsq − 2×dot) for 4 pairs in f32. */
|
|
653
|
+
NK_INTERNAL void nk_euclidean_through_f32_from_dot_powervsx_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
654
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
655
|
+
nk_vf32x4_t dots_f32x4 = dots.vf32x4;
|
|
656
|
+
nk_vf32x4_t query_f32x4 = vec_splats(query_sumsq);
|
|
657
|
+
nk_vf32x4_t sum_sq_f32x4 = vec_add(query_f32x4, target_sumsqs.vf32x4);
|
|
658
|
+
// dist_sq = sum_sq − 2 × dot
|
|
659
|
+
nk_vf32x4_t dist_sq_f32x4 = vec_madd(vec_splats(-2.0f), dots_f32x4, sum_sq_f32x4);
|
|
660
|
+
// Clamp and sqrt
|
|
661
|
+
dist_sq_f32x4 = vec_max(dist_sq_f32x4, vec_splats(0.0f));
|
|
662
|
+
nk_vf32x4_t dist_f32x4 = vec_sqrt(dist_sq_f32x4);
|
|
663
|
+
results->vf32x4 = dist_f32x4;
|
|
664
|
+
}
|
|
665
|
+
|
|
666
|
+
/** @brief Angular from_dot for i32 accumulators: cast to f32, rsqrt+NR, clamp. 4 pairs. */
|
|
667
|
+
NK_INTERNAL void nk_angular_through_i32_from_dot_powervsx_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
|
|
668
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
669
|
+
nk_vi32x4_t dots_i32x4 = dots.vi32x4;
|
|
670
|
+
nk_vf32x4_t dots_f32x4 = vec_ctf(dots_i32x4, 0);
|
|
671
|
+
nk_vf32x4_t query_f32x4 = vec_splats((nk_f32_t)query_sumsq);
|
|
672
|
+
nk_vi32x4_t targets_i32x4 = target_sumsqs.vi32x4;
|
|
673
|
+
nk_vf32x4_t products_f32x4 = vec_mul(query_f32x4, vec_ctf(targets_i32x4, 0));
|
|
674
|
+
nk_vf32x4_t rsqrt_f32x4 = nk_rsqrt_f32x4_powervsx_(products_f32x4);
|
|
675
|
+
nk_vf32x4_t normalized_f32x4 = vec_mul(dots_f32x4, rsqrt_f32x4);
|
|
676
|
+
nk_vf32x4_t angular_f32x4 = vec_sub(vec_splats(1.0f), normalized_f32x4);
|
|
677
|
+
nk_vf32x4_t result_f32x4 = vec_max(angular_f32x4, vec_splats(0.0f));
|
|
678
|
+
results->vf32x4 = result_f32x4;
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
/** @brief Euclidean from_dot for i32 accumulators: cast to f32, then √(a² + b² − 2ab). 4 pairs. */
|
|
682
|
+
NK_INTERNAL void nk_euclidean_through_i32_from_dot_powervsx_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
|
|
683
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
684
|
+
nk_vi32x4_t dots_i32x4 = dots.vi32x4;
|
|
685
|
+
nk_vf32x4_t dots_f32x4 = vec_ctf(dots_i32x4, 0);
|
|
686
|
+
nk_vf32x4_t query_f32x4 = vec_splats((nk_f32_t)query_sumsq);
|
|
687
|
+
nk_vi32x4_t targets_i32x4 = target_sumsqs.vi32x4;
|
|
688
|
+
nk_vf32x4_t sum_sq_f32x4 = vec_add(query_f32x4, vec_ctf(targets_i32x4, 0));
|
|
689
|
+
nk_vf32x4_t dist_sq_f32x4 = vec_madd(vec_splats(-2.0f), dots_f32x4, sum_sq_f32x4);
|
|
690
|
+
dist_sq_f32x4 = vec_max(dist_sq_f32x4, vec_splats(0.0f));
|
|
691
|
+
nk_vf32x4_t dist_f32x4 = vec_sqrt(dist_sq_f32x4);
|
|
692
|
+
results->vf32x4 = dist_f32x4;
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
/** @brief Angular from_dot for u32 accumulators: cast to f32, rsqrt+NR, clamp. 4 pairs. */
|
|
696
|
+
NK_INTERNAL void nk_angular_through_u32_from_dot_powervsx_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
|
|
697
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
698
|
+
nk_vu32x4_t dots_u32x4 = dots.vu32x4;
|
|
699
|
+
nk_vf32x4_t dots_f32x4 = vec_ctf(dots_u32x4, 0);
|
|
700
|
+
nk_vf32x4_t query_f32x4 = vec_splats((nk_f32_t)query_sumsq);
|
|
701
|
+
nk_vu32x4_t targets_u32x4 = target_sumsqs.vu32x4;
|
|
702
|
+
nk_vf32x4_t products_f32x4 = vec_mul(query_f32x4, vec_ctf(targets_u32x4, 0));
|
|
703
|
+
nk_vf32x4_t rsqrt_f32x4 = nk_rsqrt_f32x4_powervsx_(products_f32x4);
|
|
704
|
+
nk_vf32x4_t normalized_f32x4 = vec_mul(dots_f32x4, rsqrt_f32x4);
|
|
705
|
+
nk_vf32x4_t angular_f32x4 = vec_sub(vec_splats(1.0f), normalized_f32x4);
|
|
706
|
+
nk_vf32x4_t result_f32x4 = vec_max(angular_f32x4, vec_splats(0.0f));
|
|
707
|
+
results->vf32x4 = result_f32x4;
|
|
708
|
+
}
|
|
709
|
+
|
|
710
|
+
/** @brief Euclidean from_dot for u32 accumulators: cast to f32, then √(a² + b² − 2ab). 4 pairs. */
|
|
711
|
+
NK_INTERNAL void nk_euclidean_through_u32_from_dot_powervsx_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
|
|
712
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
713
|
+
nk_vu32x4_t dots_u32x4 = dots.vu32x4;
|
|
714
|
+
nk_vf32x4_t dots_f32x4 = vec_ctf(dots_u32x4, 0);
|
|
715
|
+
nk_vf32x4_t query_f32x4 = vec_splats((nk_f32_t)query_sumsq);
|
|
716
|
+
nk_vu32x4_t targets_u32x4 = target_sumsqs.vu32x4;
|
|
717
|
+
nk_vf32x4_t sum_sq_f32x4 = vec_add(query_f32x4, vec_ctf(targets_u32x4, 0));
|
|
718
|
+
nk_vf32x4_t dist_sq_f32x4 = vec_madd(vec_splats(-2.0f), dots_f32x4, sum_sq_f32x4);
|
|
719
|
+
dist_sq_f32x4 = vec_max(dist_sq_f32x4, vec_splats(0.0f));
|
|
720
|
+
nk_vf32x4_t dist_f32x4 = vec_sqrt(dist_sq_f32x4);
|
|
721
|
+
results->vf32x4 = dist_f32x4;
|
|
722
|
+
}
|
|
723
|
+
|
|
724
|
+
#pragma endregion I8 and U8 Integers
|
|
725
|
+
|
|
726
|
+
#if defined(__clang__)
|
|
727
|
+
#pragma clang attribute pop
|
|
728
|
+
#elif defined(__GNUC__)
|
|
729
|
+
#pragma GCC pop_options
|
|
730
|
+
#endif
|
|
731
|
+
|
|
732
|
+
#if defined(__cplusplus)
|
|
733
|
+
} // extern "C"
|
|
734
|
+
#endif
|
|
735
|
+
|
|
736
|
+
#endif // NK_TARGET_POWERVSX
|
|
737
|
+
#endif // NK_TARGET_POWER_
|
|
738
|
+
#endif // NK_SPATIAL_POWERVSX_H
|