numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -0,0 +1,483 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for LoongArch LASX (256-bit).
|
|
3
|
+
* @file include/numkong/spatial/loongsonasx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* @section spatial_loongsonasx_instructions Key LASX Spatial Instructions
|
|
10
|
+
*
|
|
11
|
+
* LASX provides 256-bit SIMD operations using __m256i as the universal vector type.
|
|
12
|
+
* All intrinsics are prefixed with __lasx_. Float operations reinterpret __m256i as
|
|
13
|
+
* f32x8 or f64x4. Integer widening multiply-accumulate chains handle i8/u8 distances.
|
|
14
|
+
*
|
|
15
|
+
* For F32 spatial distances, upcasting to F64 and downcasting back is faster than stable
|
|
16
|
+
* summation algorithms. For F64 angular we use the Dot2 algorithm (Ogita-Rump-Oishi, 2005)
|
|
17
|
+
* for the cross-product accumulation, while self-products use simple FMA since all terms
|
|
18
|
+
* are non-negative and don't suffer from cancellation.
|
|
19
|
+
*/
|
|
20
|
+
#ifndef NK_SPATIAL_LOONGSONASX_H
|
|
21
|
+
#define NK_SPATIAL_LOONGSONASX_H
|
|
22
|
+
|
|
23
|
+
#if NK_TARGET_LOONGARCH_
|
|
24
|
+
#if NK_TARGET_LOONGSONASX
|
|
25
|
+
|
|
26
|
+
#include "numkong/types.h"
|
|
27
|
+
#include "numkong/spatial/serial.h"
|
|
28
|
+
#include "numkong/dot/loongsonasx.h" //
|
|
29
|
+
#include "numkong/cast/loongsonasx.h" // `nk_bf16x8_to_f32x8_loongsonasx_`
|
|
30
|
+
#include "numkong/scalar/loongsonasx.h" // `nk_f32_sqrt_loongsonasx`, `nk_f64_sqrt_loongsonasx`
|
|
31
|
+
|
|
32
|
+
#if defined(__cplusplus)
|
|
33
|
+
extern "C" {
|
|
34
|
+
#endif
|
|
35
|
+
|
|
36
|
+
#pragma region Angular Normalize Helpers
|
|
37
|
+
|
|
38
|
+
/**
 *  @brief Turns a dot product and two squared norms into an angular (cosine) distance in f64.
 *  @param ab Dot product of the two vectors.
 *  @param a2 Squared norm of the first vector.
 *  @param b2 Squared norm of the second vector.
 *  @return 0 when both squared norms are zero, 1 when the dot product is zero,
 *          otherwise `1 - ab / (sqrt(a2) * sqrt(b2))` clamped from below at 0.
 */
NK_INTERNAL nk_f64_t nk_angular_normalize_f64_loongsonasx_(nk_f64_t ab, nk_f64_t a2, nk_f64_t b2) {
    // Two all-zero vectors are treated as identical.
    if (a2 == 0 && b2 == 0) return 0;
    // A zero dot product yields the maximal distance of this branch and avoids dividing by a zero norm.
    if (ab == 0) return 1;
    nk_f64_t const norm_a = nk_f64_sqrt_loongsonasx(a2);
    nk_f64_t const norm_b = nk_f64_sqrt_loongsonasx(b2);
    nk_f64_t const distance = 1 - ab / (norm_a * norm_b);
    // Rounding can push the value slightly below zero; clamp it.
    if (distance > 0) return distance;
    return 0;
}
|
|
44
|
+
|
|
45
|
+
/**
 *  @brief Turns integer dot-product and squared-norm accumulators into an f32 angular distance.
 *  @param ab Dot product of the two vectors.
 *  @param a2 Squared norm of the first vector.
 *  @param b2 Squared norm of the second vector.
 *  @return 0 when both squared norms are zero, 1 when the dot product is zero,
 *          otherwise `1 - ab * rsqrt(a2) * rsqrt(b2)` clamped from below at 0.
 */
NK_INTERNAL nk_f32_t nk_angular_normalize_i32_loongsonasx_(nk_i32_t ab, nk_i32_t a2, nk_i32_t b2) {
    if (a2 == 0 && b2 == 0) return 0;
    if (ab == 0) return 1;
    // Reciprocal square roots replace the division by the product of the norms.
    nk_f32_t const inverse_norm_a = nk_f32_rsqrt_loongsonasx((nk_f32_t)a2);
    nk_f32_t const inverse_norm_b = nk_f32_rsqrt_loongsonasx((nk_f32_t)b2);
    nk_f32_t const distance = 1.0f - (nk_f32_t)ab * inverse_norm_a * inverse_norm_b;
    if (distance > 0) return distance;
    return 0;
}
|
|
52
|
+
|
|
53
|
+
#pragma endregion Angular Normalize Helpers
|
|
54
|
+
|
|
55
|
+
#pragma region I8 and U8 Integers
|
|
56
|
+
|
|
57
|
+
/**
 *  @brief Squared Euclidean distance between two i8 vectors.
 *  @param a First vector of `n` signed bytes.
 *  @param b Second vector of `n` signed bytes.
 *  @param n Number of elements in each vector.
 *  @param result Output: sum of `(a[i] - b[i])^2` as an unsigned 32-bit integer.
 *
 *  The difference of two i8 values spans [-255, 255] and does not fit in 8 bits, so the
 *  subtraction is performed with widening to i16 (even and odd bytes separately). The squares
 *  reach 65025 and do not fit in i16 either, so they are produced by widening i16 → i32
 *  multiplies before being accumulated.
 */
NK_PUBLIC void nk_sqeuclidean_i8_loongsonasx(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
    __m256i sum_i32x8 = __lasx_xvreplgr2vr_w(0);
    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256i a_i8x32 = __lasx_xvld(a + i, 0);
        __m256i b_i8x32 = __lasx_xvld(b + i, 0);
        // Widening subtraction: no wrap-around for differences outside [-128, 127].
        __m256i diff_even_i16x16 = __lasx_xvsubwev_h_b(a_i8x32, b_i8x32);
        __m256i diff_odd_i16x16 = __lasx_xvsubwod_h_b(a_i8x32, b_i8x32);
        // Widening squares: every i16 difference is squared into its own i32 lane.
        __m256i sq_even_low_i32x8 = __lasx_xvmulwev_w_h(diff_even_i16x16, diff_even_i16x16);
        __m256i sq_even_high_i32x8 = __lasx_xvmulwod_w_h(diff_even_i16x16, diff_even_i16x16);
        __m256i sq_odd_low_i32x8 = __lasx_xvmulwev_w_h(diff_odd_i16x16, diff_odd_i16x16);
        __m256i sq_odd_high_i32x8 = __lasx_xvmulwod_w_h(diff_odd_i16x16, diff_odd_i16x16);
        sum_i32x8 = __lasx_xvadd_w(sum_i32x8, __lasx_xvadd_w(sq_even_low_i32x8, sq_even_high_i32x8));
        sum_i32x8 = __lasx_xvadd_w(sum_i32x8, __lasx_xvadd_w(sq_odd_low_i32x8, sq_odd_high_i32x8));
    }
    // Accumulate the tail in unsigned arithmetic: the output is unsigned and wrap-around is well defined.
    nk_u32_t sum = (nk_u32_t)nk_reduce_add_i32x8_loongsonasx_(sum_i32x8);
    for (; i < n; ++i) {
        nk_i32_t diff = (nk_i32_t)a[i] - b[i];
        sum += (nk_u32_t)(diff * diff);
    }
    *result = sum;
}
|
|
76
|
+
|
|
77
|
+
/**
 *  @brief Euclidean distance between two i8 vectors.
 *  @param a First vector of `n` signed bytes.
 *  @param b Second vector of `n` signed bytes.
 *  @param n Number of elements in each vector.
 *  @param result Output: square root of the value produced by `nk_sqeuclidean_i8_loongsonasx`, as f32.
 */
NK_PUBLIC void nk_euclidean_i8_loongsonasx(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_u32_t squared_distance = 0;
    nk_sqeuclidean_i8_loongsonasx(a, b, n, &squared_distance);
    nk_f32_t const squared_distance_f32 = (nk_f32_t)squared_distance;
    *result = nk_f32_sqrt_loongsonasx(squared_distance_f32);
}
|
|
82
|
+
|
|
83
|
+
/**
 *  @brief Angular (cosine) distance between two i8 vectors.
 *  @param a First vector of `n` signed bytes.
 *  @param b Second vector of `n` signed bytes.
 *  @param n Number of elements in each vector.
 *  @param result Output: value of `nk_angular_normalize_i32_loongsonasx_` applied to the
 *                accumulated dot product and squared norms.
 *
 *  A single i8 × i8 product fits in i16 (at most 16384), but the sum of an even and an odd
 *  product may reach 32768 and wrap. Even and odd products are therefore kept in separate
 *  i16 vectors and each is widened to i32 with a pairwise horizontal add before accumulation.
 */
NK_PUBLIC void nk_angular_i8_loongsonasx(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256i dot_i32x8 = __lasx_xvreplgr2vr_w(0);
    __m256i a_sq_i32x8 = __lasx_xvreplgr2vr_w(0);
    __m256i b_sq_i32x8 = __lasx_xvreplgr2vr_w(0);
    nk_size_t i = 0;
    for (; i + 32 <= n; i += 32) {
        __m256i a_i8x32 = __lasx_xvld(a + i, 0);
        __m256i b_i8x32 = __lasx_xvld(b + i, 0);
        // dot(a, b)
        __m256i ab_even_i16x16 = __lasx_xvmulwev_h_b(a_i8x32, b_i8x32);
        __m256i ab_odd_i16x16 = __lasx_xvmulwod_h_b(a_i8x32, b_i8x32);
        dot_i32x8 = __lasx_xvadd_w(dot_i32x8, __lasx_xvhaddw_w_h(ab_even_i16x16, ab_even_i16x16));
        dot_i32x8 = __lasx_xvadd_w(dot_i32x8, __lasx_xvhaddw_w_h(ab_odd_i16x16, ab_odd_i16x16));
        // norm_sq(a)
        __m256i aa_even_i16x16 = __lasx_xvmulwev_h_b(a_i8x32, a_i8x32);
        __m256i aa_odd_i16x16 = __lasx_xvmulwod_h_b(a_i8x32, a_i8x32);
        a_sq_i32x8 = __lasx_xvadd_w(a_sq_i32x8, __lasx_xvhaddw_w_h(aa_even_i16x16, aa_even_i16x16));
        a_sq_i32x8 = __lasx_xvadd_w(a_sq_i32x8, __lasx_xvhaddw_w_h(aa_odd_i16x16, aa_odd_i16x16));
        // norm_sq(b)
        __m256i bb_even_i16x16 = __lasx_xvmulwev_h_b(b_i8x32, b_i8x32);
        __m256i bb_odd_i16x16 = __lasx_xvmulwod_h_b(b_i8x32, b_i8x32);
        b_sq_i32x8 = __lasx_xvadd_w(b_sq_i32x8, __lasx_xvhaddw_w_h(bb_even_i16x16, bb_even_i16x16));
        b_sq_i32x8 = __lasx_xvadd_w(b_sq_i32x8, __lasx_xvhaddw_w_h(bb_odd_i16x16, bb_odd_i16x16));
    }
    nk_i32_t dot = nk_reduce_add_i32x8_loongsonasx_(dot_i32x8);
    nk_i32_t a_sq = nk_reduce_add_i32x8_loongsonasx_(a_sq_i32x8);
    nk_i32_t b_sq = nk_reduce_add_i32x8_loongsonasx_(b_sq_i32x8);
    // Scalar tail for the remaining (fewer than 32) elements.
    for (; i < n; ++i) {
        nk_i32_t a_val = a[i], b_val = b[i];
        dot += a_val * b_val;
        a_sq += a_val * a_val;
        b_sq += b_val * b_val;
    }
    *result = nk_angular_normalize_i32_loongsonasx_(dot, a_sq, b_sq);
}
|
|
118
|
+
|
|
119
|
+
/**
 *  @brief Squared Euclidean distance between two u8 vectors.
 *  @param a First vector of `n` unsigned bytes.
 *  @param b Second vector of `n` unsigned bytes.
 *  @param n Number of elements in each vector.
 *  @param result Output: sum of `(a[i] - b[i])^2` as an unsigned 32-bit integer.
 *
 *  Bytes are zero-extended to 16 bits by interleaving with a zero vector, subtracted in
 *  16 bits, and squared with widening 16 → 32-bit multiplies.
 */
NK_PUBLIC void nk_sqeuclidean_u8_loongsonasx(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
    __m256i const zero_i8x32 = __lasx_xvreplgr2vr_b(0);
    __m256i acc_i32x8 = __lasx_xvreplgr2vr_w(0);
    nk_size_t i = 0;
    while (i + 32 <= n) {
        __m256i const a_vec = __lasx_xvld(a + i, 0);
        __m256i const b_vec = __lasx_xvld(b + i, 0);
        // Zero-extend and subtract: the 16-bit differences lie in [-255, 255].
        __m256i const delta_low_i16x16 = __lasx_xvsub_h( //
            __lasx_xvilvl_b(zero_i8x32, a_vec), __lasx_xvilvl_b(zero_i8x32, b_vec));
        __m256i const delta_high_i16x16 = __lasx_xvsub_h( //
            __lasx_xvilvh_b(zero_i8x32, a_vec), __lasx_xvilvh_b(zero_i8x32, b_vec));
        // Square even and odd 16-bit lanes into 32-bit lanes and fold them together.
        __m256i const squares_low_i32x8 = __lasx_xvadd_w( //
            __lasx_xvmulwev_w_h(delta_low_i16x16, delta_low_i16x16),
            __lasx_xvmulwod_w_h(delta_low_i16x16, delta_low_i16x16));
        __m256i const squares_high_i32x8 = __lasx_xvadd_w( //
            __lasx_xvmulwev_w_h(delta_high_i16x16, delta_high_i16x16),
            __lasx_xvmulwod_w_h(delta_high_i16x16, delta_high_i16x16));
        acc_i32x8 = __lasx_xvadd_w(acc_i32x8, squares_low_i32x8);
        acc_i32x8 = __lasx_xvadd_w(acc_i32x8, squares_high_i32x8);
        i += 32;
    }
    nk_i32_t total = nk_reduce_add_i32x8_loongsonasx_(acc_i32x8);
    // Scalar tail for the remaining (fewer than 32) elements.
    while (i < n) {
        nk_i32_t const delta = (nk_i32_t)a[i] - b[i];
        total += delta * delta;
        ++i;
    }
    *result = (nk_u32_t)total;
}
|
|
146
|
+
|
|
147
|
+
/**
 *  @brief Euclidean distance between two u8 vectors.
 *  @param a First vector of `n` unsigned bytes.
 *  @param b Second vector of `n` unsigned bytes.
 *  @param n Number of elements in each vector.
 *  @param result Output: square root of the value produced by `nk_sqeuclidean_u8_loongsonasx`, as f32.
 */
NK_PUBLIC void nk_euclidean_u8_loongsonasx(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_u32_t squared_distance = 0;
    nk_sqeuclidean_u8_loongsonasx(a, b, n, &squared_distance);
    nk_f32_t const squared_distance_f32 = (nk_f32_t)squared_distance;
    *result = nk_f32_sqrt_loongsonasx(squared_distance_f32);
}
|
|
152
|
+
|
|
153
|
+
/**
 *  @brief Angular (cosine) distance between two u8 vectors.
 *  @param a First vector of `n` unsigned bytes.
 *  @param b Second vector of `n` unsigned bytes.
 *  @param n Number of elements in each vector.
 *  @param result Output: value of `nk_angular_normalize_i32_loongsonasx_` applied to the
 *                accumulated dot product and squared norms.
 *
 *  Bytes are zero-extended to 16 bits by interleaving with a zero vector; all three
 *  accumulators (a·b, a·a, b·b) are built from widening 16 → 32-bit multiplies.
 */
NK_PUBLIC void nk_angular_u8_loongsonasx(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256i const zero_i8x32 = __lasx_xvreplgr2vr_b(0);
    __m256i ab_acc_i32x8 = __lasx_xvreplgr2vr_w(0);
    __m256i aa_acc_i32x8 = __lasx_xvreplgr2vr_w(0);
    __m256i bb_acc_i32x8 = __lasx_xvreplgr2vr_w(0);
    nk_size_t i = 0;
    while (i + 32 <= n) {
        __m256i const a_vec = __lasx_xvld(a + i, 0);
        __m256i const b_vec = __lasx_xvld(b + i, 0);
        // Zero-extend both inputs to 16-bit lanes.
        __m256i const a_lo = __lasx_xvilvl_b(zero_i8x32, a_vec);
        __m256i const a_hi = __lasx_xvilvh_b(zero_i8x32, a_vec);
        __m256i const b_lo = __lasx_xvilvl_b(zero_i8x32, b_vec);
        __m256i const b_hi = __lasx_xvilvh_b(zero_i8x32, b_vec);

        // Cross products a·b
        __m256i const ab_lo = __lasx_xvadd_w(__lasx_xvmulwev_w_h(a_lo, b_lo), __lasx_xvmulwod_w_h(a_lo, b_lo));
        __m256i const ab_hi = __lasx_xvadd_w(__lasx_xvmulwev_w_h(a_hi, b_hi), __lasx_xvmulwod_w_h(a_hi, b_hi));
        ab_acc_i32x8 = __lasx_xvadd_w(ab_acc_i32x8, ab_lo);
        ab_acc_i32x8 = __lasx_xvadd_w(ab_acc_i32x8, ab_hi);

        // Self products a·a
        __m256i const aa_lo = __lasx_xvadd_w(__lasx_xvmulwev_w_h(a_lo, a_lo), __lasx_xvmulwod_w_h(a_lo, a_lo));
        __m256i const aa_hi = __lasx_xvadd_w(__lasx_xvmulwev_w_h(a_hi, a_hi), __lasx_xvmulwod_w_h(a_hi, a_hi));
        aa_acc_i32x8 = __lasx_xvadd_w(aa_acc_i32x8, aa_lo);
        aa_acc_i32x8 = __lasx_xvadd_w(aa_acc_i32x8, aa_hi);

        // Self products b·b
        __m256i const bb_lo = __lasx_xvadd_w(__lasx_xvmulwev_w_h(b_lo, b_lo), __lasx_xvmulwod_w_h(b_lo, b_lo));
        __m256i const bb_hi = __lasx_xvadd_w(__lasx_xvmulwev_w_h(b_hi, b_hi), __lasx_xvmulwod_w_h(b_hi, b_hi));
        bb_acc_i32x8 = __lasx_xvadd_w(bb_acc_i32x8, bb_lo);
        bb_acc_i32x8 = __lasx_xvadd_w(bb_acc_i32x8, bb_hi);

        i += 32;
    }
    nk_i32_t ab_total = nk_reduce_add_i32x8_loongsonasx_(ab_acc_i32x8);
    nk_i32_t aa_total = nk_reduce_add_i32x8_loongsonasx_(aa_acc_i32x8);
    nk_i32_t bb_total = nk_reduce_add_i32x8_loongsonasx_(bb_acc_i32x8);
    // Scalar tail for the remaining (fewer than 32) elements.
    while (i < n) {
        nk_i32_t const a_scalar = a[i];
        nk_i32_t const b_scalar = b[i];
        ab_total += a_scalar * b_scalar;
        aa_total += a_scalar * a_scalar;
        bb_total += b_scalar * b_scalar;
        ++i;
    }
    *result = nk_angular_normalize_i32_loongsonasx_(ab_total, aa_total, bb_total);
}
|
|
199
|
+
|
|
200
|
+
#pragma endregion I8 and U8 Integers
|
|
201
|
+
|
|
202
|
+
#pragma region F32 and F64 Floats
|
|
203
|
+
|
|
204
|
+
/**
 *  @brief Squared Euclidean distance between two f32 vectors, accumulated in f64.
 *  @param a First vector of `n` floats.
 *  @param b Second vector of `n` floats.
 *  @param n Number of elements in each vector.
 *  @param result Output: sum of `(a[i] - b[i])^2` as f64.
 *
 *  Each group of 8 floats is widened to two f64x4 vectors (`xvfcvtl` / `xvfcvth`), and each
 *  half feeds its own FMA accumulator; the two accumulators are added before the reduction.
 */
NK_PUBLIC void nk_sqeuclidean_f32_loongsonasx(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    __m256d acc_low_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    __m256d acc_high_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    nk_size_t i = 0;
    while (i + 8 <= n) {
        __m256 const a_vec = (__m256)__lasx_xvld(a + i, 0);
        __m256 const b_vec = (__m256)__lasx_xvld(b + i, 0);
        // Differences are taken after widening, i.e. in f64.
        __m256d const delta_low_f64x4 = __lasx_xvfsub_d(__lasx_xvfcvtl_d_s(a_vec), __lasx_xvfcvtl_d_s(b_vec));
        __m256d const delta_high_f64x4 = __lasx_xvfsub_d(__lasx_xvfcvth_d_s(a_vec), __lasx_xvfcvth_d_s(b_vec));
        acc_low_f64x4 = __lasx_xvfmadd_d(delta_low_f64x4, delta_low_f64x4, acc_low_f64x4);
        acc_high_f64x4 = __lasx_xvfmadd_d(delta_high_f64x4, delta_high_f64x4, acc_high_f64x4);
        i += 8;
    }
    nk_f64_t sum = nk_reduce_add_f64x4_loongsonasx_(__lasx_xvfadd_d(acc_low_f64x4, acc_high_f64x4));
    // Scalar tail for the remaining (fewer than 8) elements.
    while (i < n) {
        nk_f64_t diff = (nk_f64_t)a[i] - b[i];
        sum += diff * diff;
        ++i;
    }
    *result = sum;
}
|
|
228
|
+
|
|
229
|
+
/**
 *  @brief Euclidean distance between two f32 vectors, reported in f64.
 *  @param a First vector of `n` floats.
 *  @param b Second vector of `n` floats.
 *  @param n Number of elements in each vector.
 *  @param result Output: square root of the value produced by `nk_sqeuclidean_f32_loongsonasx`.
 */
NK_PUBLIC void nk_euclidean_f32_loongsonasx(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_f64_t squared_distance;
    nk_sqeuclidean_f32_loongsonasx(a, b, n, &squared_distance);
    *result = nk_f64_sqrt_loongsonasx(squared_distance);
}
|
|
233
|
+
|
|
234
|
+
/**
 *  @brief Angular (cosine) distance between two f32 vectors, accumulated in f64.
 *  @param a First vector of `n` floats.
 *  @param b Second vector of `n` floats.
 *  @param n Number of elements in each vector.
 *  @param result Output: value of `nk_angular_normalize_f64_loongsonasx_` applied to the
 *                accumulated dot product and squared norms.
 *
 *  Each group of 8 floats is widened to two f64x4 vectors (`xvfcvtl` / `xvfcvth`); the dot
 *  product and both squared norms keep separate accumulators for the two halves.
 */
NK_PUBLIC void nk_angular_f32_loongsonasx(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    __m256d ab_low_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    __m256d ab_high_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    __m256d aa_low_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    __m256d aa_high_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    __m256d bb_low_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    __m256d bb_high_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    nk_size_t i = 0;
    while (i + 8 <= n) {
        __m256 const a_vec = (__m256)__lasx_xvld(a + i, 0);
        __m256 const b_vec = (__m256)__lasx_xvld(b + i, 0);
        __m256d const a_low = __lasx_xvfcvtl_d_s(a_vec);
        __m256d const b_low = __lasx_xvfcvtl_d_s(b_vec);
        __m256d const a_high = __lasx_xvfcvth_d_s(a_vec);
        __m256d const b_high = __lasx_xvfcvth_d_s(b_vec);
        ab_low_f64x4 = __lasx_xvfmadd_d(a_low, b_low, ab_low_f64x4);
        ab_high_f64x4 = __lasx_xvfmadd_d(a_high, b_high, ab_high_f64x4);
        aa_low_f64x4 = __lasx_xvfmadd_d(a_low, a_low, aa_low_f64x4);
        aa_high_f64x4 = __lasx_xvfmadd_d(a_high, a_high, aa_high_f64x4);
        bb_low_f64x4 = __lasx_xvfmadd_d(b_low, b_low, bb_low_f64x4);
        bb_high_f64x4 = __lasx_xvfmadd_d(b_high, b_high, bb_high_f64x4);
        i += 8;
    }
    // Merge the two halves of every accumulator, then reduce horizontally.
    nk_f64_t dot = nk_reduce_add_f64x4_loongsonasx_(__lasx_xvfadd_d(ab_low_f64x4, ab_high_f64x4));
    nk_f64_t a_sq = nk_reduce_add_f64x4_loongsonasx_(__lasx_xvfadd_d(aa_low_f64x4, aa_high_f64x4));
    nk_f64_t b_sq = nk_reduce_add_f64x4_loongsonasx_(__lasx_xvfadd_d(bb_low_f64x4, bb_high_f64x4));
    // Scalar tail for the remaining (fewer than 8) elements.
    while (i < n) {
        nk_f64_t a_val = a[i], b_val = b[i];
        dot += a_val * b_val;
        a_sq += a_val * a_val;
        b_sq += b_val * b_val;
        ++i;
    }
    *result = nk_angular_normalize_f64_loongsonasx_(dot, a_sq, b_sq);
}
|
|
267
|
+
|
|
268
|
+
/**
 *  @brief Squared Euclidean distance between two f64 vectors.
 *  @param a First vector of `n` doubles.
 *  @param b Second vector of `n` doubles.
 *  @param n Number of elements in each vector.
 *  @param result Output: sum of `(a[i] - b[i])^2`.
 */
NK_PUBLIC void nk_sqeuclidean_f64_loongsonasx(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    __m256d acc_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    nk_size_t i = 0;
    while (i + 4 <= n) {
        __m256d const a_vec = (__m256d)__lasx_xvld(a + i, 0);
        __m256d const b_vec = (__m256d)__lasx_xvld(b + i, 0);
        __m256d const delta_f64x4 = __lasx_xvfsub_d(a_vec, b_vec);
        // acc += delta * delta, fused
        acc_f64x4 = __lasx_xvfmadd_d(delta_f64x4, delta_f64x4, acc_f64x4);
        i += 4;
    }
    nk_f64_t sum = nk_reduce_add_f64x4_loongsonasx_(acc_f64x4);
    // Scalar tail for the remaining (fewer than 4) elements.
    while (i < n) {
        nk_f64_t diff = a[i] - b[i];
        sum += diff * diff;
        ++i;
    }
    *result = sum;
}
|
|
284
|
+
|
|
285
|
+
/**
 *  @brief Euclidean distance between two f64 vectors.
 *  @param a First vector of `n` doubles.
 *  @param b Second vector of `n` doubles.
 *  @param n Number of elements in each vector.
 *  @param result Output: square root of the value produced by `nk_sqeuclidean_f64_loongsonasx`.
 */
NK_PUBLIC void nk_euclidean_f64_loongsonasx(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_f64_t squared_distance;
    nk_sqeuclidean_f64_loongsonasx(a, b, n, &squared_distance);
    *result = nk_f64_sqrt_loongsonasx(squared_distance);
}
|
|
289
|
+
|
|
290
|
+
/**
 *  @brief Angular (cosine) distance between two f64 vectors with a compensated dot product.
 *  @param a First vector of `n` doubles.
 *  @param b Second vector of `n` doubles.
 *  @param n Number of elements in each vector.
 *  @param result Output: value of `nk_angular_normalize_f64_loongsonasx_` applied to the
 *                accumulated dot product and squared norms.
 *
 *  The cross product a·b tracks a running sum plus a separate compensation term that collects
 *  the rounding error of every multiplication (recovered with a fused multiply-subtract) and of
 *  every addition (recovered with a branch-free two-sum). The squared norms use plain FMA.
 */
NK_PUBLIC void nk_angular_f64_loongsonasx(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    __m256d running_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    __m256d correction_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    __m256d aa_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    __m256d bb_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    nk_size_t i = 0;
    while (i + 4 <= n) {
        __m256d const a_vec = (__m256d)__lasx_xvld(a + i, 0);
        __m256d const b_vec = (__m256d)__lasx_xvld(b + i, 0);

        // Rounded product and its exact rounding error: a * b - round(a * b).
        __m256d const rounded_product = __lasx_xvfmul_d(a_vec, b_vec);
        __m256d const product_residual = __lasx_xvfmsub_d(a_vec, b_vec, rounded_product);

        // Two-sum of the running total and the rounded product.
        __m256d const next_running = __lasx_xvfadd_d(running_f64x4, rounded_product);
        __m256d const effective_addend = __lasx_xvfsub_d(next_running, running_f64x4);
        __m256d const effective_running = __lasx_xvfsub_d(next_running, effective_addend);
        __m256d const running_residual = __lasx_xvfsub_d(running_f64x4, effective_running);
        __m256d const addend_residual = __lasx_xvfsub_d(rounded_product, effective_addend);
        __m256d const sum_residual = __lasx_xvfadd_d(running_residual, addend_residual);

        running_f64x4 = next_running;
        correction_f64x4 = __lasx_xvfadd_d(correction_f64x4, __lasx_xvfadd_d(sum_residual, product_residual));

        // Self-products are accumulated with plain FMA.
        aa_f64x4 = __lasx_xvfmadd_d(a_vec, a_vec, aa_f64x4);
        bb_f64x4 = __lasx_xvfmadd_d(b_vec, b_vec, bb_f64x4);
        i += 4;
    }

    nk_f64_t dot = nk_dot_stable_sum_f64x4_loongsonasx_(running_f64x4, correction_f64x4);
    nk_f64_t a_sq = nk_reduce_add_f64x4_loongsonasx_(aa_f64x4);
    nk_f64_t b_sq = nk_reduce_add_f64x4_loongsonasx_(bb_f64x4);
    // Scalar tail for the remaining (fewer than 4) elements; not compensated.
    while (i < n) {
        nk_f64_t a_val = a[i], b_val = b[i];
        dot += a_val * b_val;
        a_sq += a_val * a_val;
        b_sq += b_val * b_val;
        ++i;
    }
    *result = nk_angular_normalize_f64_loongsonasx_(dot, a_sq, b_sq);
}
|
|
328
|
+
|
|
329
|
+
#pragma endregion F32 and F64 Floats
|
|
330
|
+
|
|
331
|
+
#pragma region F16 and BF16 Floats
|
|
332
|
+
|
|
333
|
+
/**
 *  @brief Combines a dot product and two squared norms into a clamped angular distance.
 *
 *  Evaluates `1 - ab * rsqrt(a2) * rsqrt(b2)` using `nk_f32_rsqrt_loongsonasx`
 *  (defined elsewhere; presumably a reciprocal square root) and never returns a negative value.
 *
 *  @param ab Accumulated product of the two vectors (callers in this file pass the dot product).
 *  @param a2 Accumulated square of the first vector.
 *  @param b2 Accumulated square of the second vector.
 *  @return 0 when both `a2` and `b2` are zero, 1 when `ab` is zero, otherwise the clamped distance.
 */
NK_INTERNAL nk_f32_t nk_angular_normalize_f32_loongsonasx_(nk_f32_t ab, nk_f32_t a2, nk_f32_t b2) {
    // Both squared norms vanish: report zero distance instead of dividing by zero.
    if (a2 == 0.0f && b2 == 0.0f) { return 0.0f; }
    // No overlap between the vectors (also reached when only one norm is zero).
    if (ab == 0.0f) { return 1.0f; }

    nk_f32_t const a_rsqrt = nk_f32_rsqrt_loongsonasx(a2);
    nk_f32_t const b_rsqrt = nk_f32_rsqrt_loongsonasx(b2);
    nk_f32_t const distance = 1.0f - ab * a_rsqrt * b_rsqrt;

    // Rounding can push the value slightly below zero; clamp it.
    if (distance > 0.0f) { return distance; }
    return 0.0f;
}
|
|
339
|
+
|
|
340
|
+
/** @brief Horizontal sum of 8 × f32 lanes in a 256-bit LASX register. */
|
|
341
|
+
NK_INTERNAL nk_f32_t nk_reduce_add_f32x8_loongsonasx_(__m256 sum_f32x8) {
|
|
342
|
+
// Add high 128-bit lane to low 128-bit lane
|
|
343
|
+
__m256 high_f32x4 = (__m256)__lasx_xvpermi_q((__m256i)sum_f32x8, (__m256i)sum_f32x8, 0x11);
|
|
344
|
+
__m256 sum_f32x4 = __lasx_xvfadd_s(sum_f32x8, high_f32x4);
|
|
345
|
+
__m256 swapped_f32x4 = (__m256)__lasx_xvshuf4i_w((__m256i)sum_f32x4, 0b01001110);
|
|
346
|
+
__m256 reduced_f32x4 = __lasx_xvfadd_s(sum_f32x4, swapped_f32x4);
|
|
347
|
+
__m256 swapped_f32x2 = (__m256)__lasx_xvshuf4i_w((__m256i)reduced_f32x4, 0b10110001);
|
|
348
|
+
__m256 reduced_f32x2 = __lasx_xvfadd_s(reduced_f32x4, swapped_f32x2);
|
|
349
|
+
nk_fui32_t c;
|
|
350
|
+
c.u = (nk_u32_t)__lasx_xvpickve2gr_w((__m256i)reduced_f32x2, 0);
|
|
351
|
+
return c.f;
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
/**
 *  @brief Squared Euclidean distance between two bf16 vectors, accumulated in f32.
 *
 *  Processes 16 elements per iteration: each 32-bit word of the loaded register holds two
 *  bf16 values, which are widened to f32 by moving their 16 bits into the upper half of the word.
 *  Leftover elements are converted one at a time with `nk_bf16_to_f32_serial`.
 *
 *  @param a First input vector.
 *  @param b Second input vector.
 *  @param n Number of elements in each vector.
 *  @param result Output location receiving the sum of squared differences.
 */
NK_PUBLIC void nk_sqeuclidean_bf16_loongsonasx(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256i const upper_half_mask_u32x8 = __lasx_xvreplgr2vr_w((int)0xFFFF0000);
    __m256 acc_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
    nk_size_t idx = 0;

    while (idx + 16 <= n) {
        __m256i a_packed_bf16x16 = __lasx_xvld(a + idx, 0);
        __m256i b_packed_bf16x16 = __lasx_xvld(b + idx, 0);

        // Even-indexed elements sit in the low 16 bits of each word: shift them up.
        __m256 a_low_f32x8 = (__m256)__lasx_xvslli_w(a_packed_bf16x16, 16);
        __m256 b_low_f32x8 = (__m256)__lasx_xvslli_w(b_packed_bf16x16, 16);
        // Odd-indexed elements already occupy the high 16 bits: mask off the rest.
        __m256 a_high_f32x8 = (__m256)__lasx_xvand_v(a_packed_bf16x16, upper_half_mask_u32x8);
        __m256 b_high_f32x8 = (__m256)__lasx_xvand_v(b_packed_bf16x16, upper_half_mask_u32x8);

        __m256 delta_low_f32x8 = __lasx_xvfsub_s(a_low_f32x8, b_low_f32x8);
        __m256 delta_high_f32x8 = __lasx_xvfsub_s(a_high_f32x8, b_high_f32x8);

        // Accumulate even contributions first, then odd ones.
        acc_f32x8 = __lasx_xvfmadd_s(delta_low_f32x8, delta_low_f32x8, acc_f32x8);
        acc_f32x8 = __lasx_xvfmadd_s(delta_high_f32x8, delta_high_f32x8, acc_f32x8);

        idx += 16;
    }

    nk_f32_t total = nk_reduce_add_f32x8_loongsonasx_(acc_f32x8);

    // Scalar tail for the final (n % 16) elements.
    while (idx < n) {
        nk_f32_t a_scalar, b_scalar;
        nk_bf16_to_f32_serial(a + idx, &a_scalar);
        nk_bf16_to_f32_serial(b + idx, &b_scalar);
        nk_f32_t delta = a_scalar - b_scalar;
        total += delta * delta;
        ++idx;
    }

    *result = total;
}
|
|
380
|
+
|
|
381
|
+
/**
 *  @brief Euclidean distance between two bf16 vectors.
 *
 *  Computes the squared distance with `nk_sqeuclidean_bf16_loongsonasx` and passes it through
 *  `nk_f32_sqrt_loongsonasx`.
 *
 *  @param a First input vector.
 *  @param b Second input vector.
 *  @param n Number of elements in each vector.
 *  @param result Output location receiving the distance.
 */
NK_PUBLIC void nk_euclidean_bf16_loongsonasx(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_bf16_loongsonasx(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_loongsonasx(squared_distance);
}
|
|
385
|
+
|
|
386
|
+
/**
 *  @brief Angular distance between two bf16 vectors, accumulated in f32.
 *
 *  Gathers the dot product and both squared norms in a single pass (16 elements per iteration,
 *  then a scalar tail) and hands them to `nk_angular_normalize_f32_loongsonasx_`.
 *
 *  @param a First input vector.
 *  @param b Second input vector.
 *  @param n Number of elements in each vector.
 *  @param result Output location receiving the normalized distance.
 */
NK_PUBLIC void nk_angular_bf16_loongsonasx(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256i const upper_half_mask_u32x8 = __lasx_xvreplgr2vr_w((int)0xFFFF0000);
    __m256 ab_acc_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
    __m256 aa_acc_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
    __m256 bb_acc_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
    nk_size_t idx = 0;

    while (idx + 16 <= n) {
        __m256i a_packed_bf16x16 = __lasx_xvld(a + idx, 0);
        __m256i b_packed_bf16x16 = __lasx_xvld(b + idx, 0);

        // Even-indexed elements: low 16 bits of each word, shifted into the f32 upper half.
        __m256 a_low_f32x8 = (__m256)__lasx_xvslli_w(a_packed_bf16x16, 16);
        __m256 b_low_f32x8 = (__m256)__lasx_xvslli_w(b_packed_bf16x16, 16);
        // Odd-indexed elements: high 16 bits kept in place, low bits cleared.
        __m256 a_high_f32x8 = (__m256)__lasx_xvand_v(a_packed_bf16x16, upper_half_mask_u32x8);
        __m256 b_high_f32x8 = (__m256)__lasx_xvand_v(b_packed_bf16x16, upper_half_mask_u32x8);

        // Each accumulator takes the even contribution first, then the odd one.
        ab_acc_f32x8 = __lasx_xvfmadd_s(a_low_f32x8, b_low_f32x8, ab_acc_f32x8);
        ab_acc_f32x8 = __lasx_xvfmadd_s(a_high_f32x8, b_high_f32x8, ab_acc_f32x8);

        aa_acc_f32x8 = __lasx_xvfmadd_s(a_low_f32x8, a_low_f32x8, aa_acc_f32x8);
        aa_acc_f32x8 = __lasx_xvfmadd_s(a_high_f32x8, a_high_f32x8, aa_acc_f32x8);

        bb_acc_f32x8 = __lasx_xvfmadd_s(b_low_f32x8, b_low_f32x8, bb_acc_f32x8);
        bb_acc_f32x8 = __lasx_xvfmadd_s(b_high_f32x8, b_high_f32x8, bb_acc_f32x8);

        idx += 16;
    }

    nk_f32_t ab_total = nk_reduce_add_f32x8_loongsonasx_(ab_acc_f32x8);
    nk_f32_t aa_total = nk_reduce_add_f32x8_loongsonasx_(aa_acc_f32x8);
    nk_f32_t bb_total = nk_reduce_add_f32x8_loongsonasx_(bb_acc_f32x8);

    // Scalar tail for the final (n % 16) elements.
    while (idx < n) {
        nk_f32_t a_scalar, b_scalar;
        nk_bf16_to_f32_serial(a + idx, &a_scalar);
        nk_bf16_to_f32_serial(b + idx, &b_scalar);
        ab_total += a_scalar * b_scalar;
        aa_total += a_scalar * a_scalar;
        bb_total += b_scalar * b_scalar;
        ++idx;
    }

    *result = nk_angular_normalize_f32_loongsonasx_(ab_total, aa_total, bb_total);
}
|
|
419
|
+
|
|
420
|
+
/**
 *  @brief Squared Euclidean distance between two f16 vectors, accumulated in f32.
 *
 *  Loads 8 half-precision values per iteration as a 128-bit vector, widens them with
 *  `nk_f16x8_to_f32x8_loongsonasx_` (defined elsewhere), and accumulates squared differences.
 *  Leftover elements are converted one at a time with `nk_f16_to_f32_serial`.
 *
 *  @param a First input vector.
 *  @param b Second input vector.
 *  @param n Number of elements in each vector.
 *  @param result Output location receiving the sum of squared differences.
 */
NK_PUBLIC void nk_sqeuclidean_f16_loongsonasx(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256 acc_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
    nk_size_t idx = 0;

    while (idx + 8 <= n) {
        __m128i a_packed_f16x8 = __lsx_vld(a + idx, 0);
        __m128i b_packed_f16x8 = __lsx_vld(b + idx, 0);
        __m256 a_wide_f32x8 = (__m256)nk_f16x8_to_f32x8_loongsonasx_(a_packed_f16x8);
        __m256 b_wide_f32x8 = (__m256)nk_f16x8_to_f32x8_loongsonasx_(b_packed_f16x8);
        __m256 delta_f32x8 = __lasx_xvfsub_s(a_wide_f32x8, b_wide_f32x8);
        acc_f32x8 = __lasx_xvfmadd_s(delta_f32x8, delta_f32x8, acc_f32x8);
        idx += 8;
    }

    nk_f32_t total = nk_reduce_add_f32x8_loongsonasx_(acc_f32x8);

    // Scalar tail for the final (n % 8) elements.
    while (idx < n) {
        nk_f32_t a_scalar, b_scalar;
        nk_f16_to_f32_serial(a + idx, &a_scalar);
        nk_f16_to_f32_serial(b + idx, &b_scalar);
        nk_f32_t delta = a_scalar - b_scalar;
        total += delta * delta;
        ++idx;
    }

    *result = total;
}
|
|
441
|
+
|
|
442
|
+
/**
 *  @brief Euclidean distance between two f16 vectors.
 *
 *  Computes the squared distance with `nk_sqeuclidean_f16_loongsonasx` and passes it through
 *  `nk_f32_sqrt_loongsonasx`.
 *
 *  @param a First input vector.
 *  @param b Second input vector.
 *  @param n Number of elements in each vector.
 *  @param result Output location receiving the distance.
 */
NK_PUBLIC void nk_euclidean_f16_loongsonasx(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_f16_loongsonasx(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_loongsonasx(squared_distance);
}
|
|
446
|
+
|
|
447
|
+
/**
 *  @brief Angular distance between two f16 vectors, accumulated in f32.
 *
 *  Gathers the dot product and both squared norms in a single pass (8 elements per iteration,
 *  widened via `nk_f16x8_to_f32x8_loongsonasx_`, then a scalar tail) and hands them to
 *  `nk_angular_normalize_f32_loongsonasx_`.
 *
 *  @param a First input vector.
 *  @param b Second input vector.
 *  @param n Number of elements in each vector.
 *  @param result Output location receiving the normalized distance.
 */
NK_PUBLIC void nk_angular_f16_loongsonasx(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m256 ab_acc_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
    __m256 aa_acc_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
    __m256 bb_acc_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
    nk_size_t idx = 0;

    while (idx + 8 <= n) {
        __m128i a_packed_f16x8 = __lsx_vld(a + idx, 0);
        __m128i b_packed_f16x8 = __lsx_vld(b + idx, 0);
        __m256 a_wide_f32x8 = (__m256)nk_f16x8_to_f32x8_loongsonasx_(a_packed_f16x8);
        __m256 b_wide_f32x8 = (__m256)nk_f16x8_to_f32x8_loongsonasx_(b_packed_f16x8);

        ab_acc_f32x8 = __lasx_xvfmadd_s(a_wide_f32x8, b_wide_f32x8, ab_acc_f32x8);
        aa_acc_f32x8 = __lasx_xvfmadd_s(a_wide_f32x8, a_wide_f32x8, aa_acc_f32x8);
        bb_acc_f32x8 = __lasx_xvfmadd_s(b_wide_f32x8, b_wide_f32x8, bb_acc_f32x8);

        idx += 8;
    }

    nk_f32_t ab_total = nk_reduce_add_f32x8_loongsonasx_(ab_acc_f32x8);
    nk_f32_t aa_total = nk_reduce_add_f32x8_loongsonasx_(aa_acc_f32x8);
    nk_f32_t bb_total = nk_reduce_add_f32x8_loongsonasx_(bb_acc_f32x8);

    // Scalar tail for the final (n % 8) elements.
    while (idx < n) {
        nk_f32_t a_scalar, b_scalar;
        nk_f16_to_f32_serial(a + idx, &a_scalar);
        nk_f16_to_f32_serial(b + idx, &b_scalar);
        ab_total += a_scalar * b_scalar;
        aa_total += a_scalar * a_scalar;
        bb_total += b_scalar * b_scalar;
        ++idx;
    }

    *result = nk_angular_normalize_f32_loongsonasx_(ab_total, aa_total, bb_total);
}
|
|
474
|
+
|
|
475
|
+
#pragma endregion F16 and BF16 Floats
|
|
476
|
+
|
|
477
|
+
#if defined(__cplusplus)
|
|
478
|
+
} // extern "C"
|
|
479
|
+
#endif
|
|
480
|
+
|
|
481
|
+
#endif // NK_TARGET_LOONGSONASX
|
|
482
|
+
#endif // NK_TARGET_LOONGARCH_
|
|
483
|
+
#endif // NK_SPATIAL_LOONGSONASX_H
|