numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,16 +8,15 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section spatial_neonsdot_instructions ARM NEON SDOT/UDOT Instructions (ARMv8.4-DotProd)
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
*
|
|
13
|
-
*
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
*
|
|
20
|
-
* vaddvq_u32 ADDV (V.4S) 4cy 1/cy 2/cy
|
|
11
|
+
* Intrinsic Instruction A76 M5
|
|
12
|
+
* vdotq_s32 SDOT (V.4S, V.16B, V.16B) 3cy @ 2p 3cy @ 4p
|
|
13
|
+
* vdotq_u32 UDOT (V.4S, V.16B, V.16B) 3cy @ 2p 3cy @ 4p
|
|
14
|
+
* vabdq_s8 SABD (V.16B, V.16B, V.16B) 3cy @ 2p 3cy @ 2p
|
|
15
|
+
* vabdq_u8 UABD (V.16B, V.16B, V.16B) 3cy @ 2p 3cy @ 2p
|
|
16
|
+
* vld1q_s8 LD1 (V.16B) 4cy @ 2p 4cy @ 3p
|
|
17
|
+
* vld1q_u8 LD1 (V.16B) 4cy @ 2p 4cy @ 3p
|
|
18
|
+
* vaddvq_s32 ADDV (V.4S) 4cy @ 1p 5cy @ 1p
|
|
19
|
+
* vaddvq_u32 ADDV (V.4S) 4cy @ 1p 5cy @ 1p
|
|
21
20
|
*
|
|
22
21
|
* The ARMv8.4-DotProd extension provides SDOT/UDOT for int8 dot products and SABD/UABD for
|
|
23
22
|
* absolute differences, enabling L2 and angular distance on quantized embeddings.
|
|
@@ -34,6 +33,7 @@
|
|
|
34
33
|
#if NK_TARGET_NEONSDOT
|
|
35
34
|
|
|
36
35
|
#include "numkong/types.h"
|
|
36
|
+
#include "numkong/cast/serial.h" // `nk_partial_load_b4x32_serial_`
|
|
37
37
|
#include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
|
|
38
38
|
|
|
39
39
|
#if defined(__cplusplus)
|
|
@@ -195,7 +195,8 @@ NK_PUBLIC void nk_angular_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_siz
|
|
|
195
195
|
b_norm_sq_i32 += b_element_i32 * b_element_i32;
|
|
196
196
|
}
|
|
197
197
|
|
|
198
|
-
*result = nk_angular_normalize_f32_neon_(dot_product_i32, a_norm_sq_i32,
|
|
198
|
+
*result = nk_angular_normalize_f32_neon_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
|
|
199
|
+
(nk_f32_t)b_norm_sq_i32);
|
|
199
200
|
}
|
|
200
201
|
|
|
201
202
|
NK_PUBLIC void nk_sqeuclidean_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
@@ -243,7 +244,174 @@ NK_PUBLIC void nk_angular_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_siz
|
|
|
243
244
|
ab += ai * bi, a2 += ai * ai, b2 += bi * bi;
|
|
244
245
|
}
|
|
245
246
|
|
|
246
|
-
*result = nk_angular_normalize_f32_neon_(ab, a2, b2);
|
|
247
|
+
*result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
/**
 *  @brief Squared Euclidean distance between two vectors of packed signed 4-bit integers.
 *
 *  Every byte of @p a and @p b carries two nibbles. The nibble count @p n is rounded up to
 *  an even number and then processed 16 bytes (32 nibbles) per iteration. The horizontal
 *  sum of the 32-bit accumulator lanes is written to @p result.
 *
 *  @param[in] a First packed input.
 *  @param[in] b Second packed input.
 *  @param[in] n Number of 4-bit elements in each input.
 *  @param[out] result Receives the accumulated squared differences.
 */
NK_PUBLIC void nk_sqeuclidean_i4_neonsdot(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result) {
    nk_size_t bytes_left = nk_size_round_up_to_multiple_(n, 2) / 2;
    uint32x4_t sum_sq_u32x4 = vdupq_n_u32(0);

    // The body runs at least once, even for an empty input, mirroring a bottom-tested loop.
    do {
        uint8x16_t a_packed_u8x16, b_packed_u8x16;
        if (bytes_left >= 16) {
            a_packed_u8x16 = vld1q_u8((nk_u8_t const *)a);
            b_packed_u8x16 = vld1q_u8((nk_u8_t const *)b);
            a += 16, b += 16, bytes_left -= 16;
        }
        else {
            // Tail: fewer than 16 bytes remain, so go through the partial-load helper, which
            // takes the count in nibbles. NOTE(review): the helper is defined elsewhere and
            // presumably zero-fills the lanes it does not read - confirm.
            nk_b128_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b4x32_serial_(a, &a_tail_vec, bytes_left * 2);
            nk_partial_load_b4x32_serial_(b, &b_tail_vec, bytes_left * 2);
            a_packed_u8x16 = a_tail_vec.u8x16;
            b_packed_u8x16 = b_tail_vec.u8x16;
            bytes_left = 0;
        }

        int8x16_t a_packed_i8x16 = vreinterpretq_s8_u8(a_packed_u8x16);
        int8x16_t b_packed_i8x16 = vreinterpretq_s8_u8(b_packed_u8x16);

        // Low nibbles: shift left to drop the high nibble, then arithmetic shift right to
        // sign-extend. High nibbles only need the arithmetic shift right.
        int8x16_t a_lo_i8x16 = vshrq_n_s8(vshlq_n_s8(a_packed_i8x16, 4), 4);
        int8x16_t b_lo_i8x16 = vshrq_n_s8(vshlq_n_s8(b_packed_i8x16, 4), 4);
        int8x16_t a_hi_i8x16 = vshrq_n_s8(a_packed_i8x16, 4);
        int8x16_t b_hi_i8x16 = vshrq_n_s8(b_packed_i8x16, 4);

        // Absolute differences are reinterpreted as unsigned bytes so that UDOT can square
        // and accumulate them.
        uint8x16_t abs_diff_lo_u8x16 = vreinterpretq_u8_s8(vabdq_s8(a_lo_i8x16, b_lo_i8x16));
        uint8x16_t abs_diff_hi_u8x16 = vreinterpretq_u8_s8(vabdq_s8(a_hi_i8x16, b_hi_i8x16));
        sum_sq_u32x4 = vdotq_u32(sum_sq_u32x4, abs_diff_lo_u8x16, abs_diff_lo_u8x16);
        sum_sq_u32x4 = vdotq_u32(sum_sq_u32x4, abs_diff_hi_u8x16, abs_diff_hi_u8x16);
    } while (bytes_left);

    *result = vaddvq_u32(sum_sq_u32x4);
}
|
|
285
|
+
|
|
286
|
+
/**
 *  @brief Euclidean distance between two vectors of packed signed 4-bit integers.
 *
 *  Delegates to nk_sqeuclidean_i4_neonsdot for the integer squared distance, converts it to
 *  single precision and passes it through nk_f32_sqrt_neon.
 *
 *  @param[in] a First packed input.
 *  @param[in] b Second packed input.
 *  @param[in] n Number of 4-bit elements in each input.
 *  @param[out] result Receives the distance.
 */
NK_PUBLIC void nk_euclidean_i4_neonsdot(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_u32_t squared_distance;
    nk_sqeuclidean_i4_neonsdot(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_neon((nk_f32_t)squared_distance);
}
|
|
291
|
+
|
|
292
|
+
/**
 *  @brief Angular distance between two vectors of packed signed 4-bit integers.
 *
 *  Accumulates the dot product of @p a and @p b together with the squared norms of each
 *  input using SDOT, then hands the three sums (converted to single precision) to
 *  nk_angular_normalize_f32_neon_, whose return value is stored in @p result.
 *
 *  @param[in] a First packed input.
 *  @param[in] b Second packed input.
 *  @param[in] n Number of 4-bit elements in each input; rounded up to an even count.
 *  @param[out] result Receives the value returned by the normalization helper.
 */
NK_PUBLIC void nk_angular_i4_neonsdot(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t bytes_left = nk_size_round_up_to_multiple_(n, 2) / 2;
    int32x4_t dot_i32x4 = vdupq_n_s32(0);
    int32x4_t a_norm_sq_i32x4 = vdupq_n_s32(0);
    int32x4_t b_norm_sq_i32x4 = vdupq_n_s32(0);

    // The body runs at least once, even for an empty input, mirroring a bottom-tested loop.
    do {
        uint8x16_t a_packed_u8x16, b_packed_u8x16;
        if (bytes_left >= 16) {
            a_packed_u8x16 = vld1q_u8((nk_u8_t const *)a);
            b_packed_u8x16 = vld1q_u8((nk_u8_t const *)b);
            a += 16, b += 16, bytes_left -= 16;
        }
        else {
            // Tail: the partial-load helper takes the count in nibbles.
            // NOTE(review): presumably zero-fills the lanes it does not read - confirm.
            nk_b128_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b4x32_serial_(a, &a_tail_vec, bytes_left * 2);
            nk_partial_load_b4x32_serial_(b, &b_tail_vec, bytes_left * 2);
            a_packed_u8x16 = a_tail_vec.u8x16;
            b_packed_u8x16 = b_tail_vec.u8x16;
            bytes_left = 0;
        }

        int8x16_t a_packed_i8x16 = vreinterpretq_s8_u8(a_packed_u8x16);
        int8x16_t b_packed_i8x16 = vreinterpretq_s8_u8(b_packed_u8x16);

        // Sign-extend both nibbles of every byte into separate 8-bit lanes.
        int8x16_t a_lo_i8x16 = vshrq_n_s8(vshlq_n_s8(a_packed_i8x16, 4), 4);
        int8x16_t b_lo_i8x16 = vshrq_n_s8(vshlq_n_s8(b_packed_i8x16, 4), 4);
        int8x16_t a_hi_i8x16 = vshrq_n_s8(a_packed_i8x16, 4);
        int8x16_t b_hi_i8x16 = vshrq_n_s8(b_packed_i8x16, 4);

        dot_i32x4 = vdotq_s32(dot_i32x4, a_lo_i8x16, b_lo_i8x16);
        dot_i32x4 = vdotq_s32(dot_i32x4, a_hi_i8x16, b_hi_i8x16);
        a_norm_sq_i32x4 = vdotq_s32(a_norm_sq_i32x4, a_lo_i8x16, a_lo_i8x16);
        a_norm_sq_i32x4 = vdotq_s32(a_norm_sq_i32x4, a_hi_i8x16, a_hi_i8x16);
        b_norm_sq_i32x4 = vdotq_s32(b_norm_sq_i32x4, b_lo_i8x16, b_lo_i8x16);
        b_norm_sq_i32x4 = vdotq_s32(b_norm_sq_i32x4, b_hi_i8x16, b_hi_i8x16);
    } while (bytes_left);

    // Collapse each accumulator to a scalar before normalizing.
    nk_f32_t dot_f32 = (nk_f32_t)vaddvq_s32(dot_i32x4);
    nk_f32_t a_norm_sq_f32 = (nk_f32_t)vaddvq_s32(a_norm_sq_i32x4);
    nk_f32_t b_norm_sq_f32 = (nk_f32_t)vaddvq_s32(b_norm_sq_i32x4);
    *result = nk_angular_normalize_f32_neon_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
332
|
+
|
|
333
|
+
/**
 *  @brief Squared Euclidean distance between two vectors of packed unsigned 4-bit integers.
 *
 *  Every byte of @p a and @p b carries two nibbles. The nibble count @p n is rounded up to
 *  an even number and then processed 16 bytes (32 nibbles) per iteration. The horizontal
 *  sum of the 32-bit accumulator lanes is written to @p result.
 *
 *  @param[in] a First packed input.
 *  @param[in] b Second packed input.
 *  @param[in] n Number of 4-bit elements in each input.
 *  @param[out] result Receives the accumulated squared differences.
 */
NK_PUBLIC void nk_sqeuclidean_u4_neonsdot(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
    nk_size_t bytes_left = nk_size_round_up_to_multiple_(n, 2) / 2;
    uint8x16_t const low_nibble_mask_u8x16 = vdupq_n_u8(0x0F);
    uint32x4_t sum_sq_u32x4 = vdupq_n_u32(0);

    // The body runs at least once, even for an empty input, mirroring a bottom-tested loop.
    do {
        uint8x16_t a_packed_u8x16, b_packed_u8x16;
        if (bytes_left >= 16) {
            a_packed_u8x16 = vld1q_u8((nk_u8_t const *)a);
            b_packed_u8x16 = vld1q_u8((nk_u8_t const *)b);
            a += 16, b += 16, bytes_left -= 16;
        }
        else {
            // Tail: the partial-load helper takes the count in nibbles.
            // NOTE(review): presumably zero-fills the lanes it does not read - confirm.
            nk_b128_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b4x32_serial_(a, &a_tail_vec, bytes_left * 2);
            nk_partial_load_b4x32_serial_(b, &b_tail_vec, bytes_left * 2);
            a_packed_u8x16 = a_tail_vec.u8x16;
            b_packed_u8x16 = b_tail_vec.u8x16;
            bytes_left = 0;
        }

        // Unsigned nibbles: mask for the low half, logical shift for the high half.
        uint8x16_t a_lo_u8x16 = vandq_u8(a_packed_u8x16, low_nibble_mask_u8x16);
        uint8x16_t a_hi_u8x16 = vshrq_n_u8(a_packed_u8x16, 4);
        uint8x16_t b_lo_u8x16 = vandq_u8(b_packed_u8x16, low_nibble_mask_u8x16);
        uint8x16_t b_hi_u8x16 = vshrq_n_u8(b_packed_u8x16, 4);

        // UDOT of the absolute difference with itself accumulates the squared differences.
        uint8x16_t abs_diff_lo_u8x16 = vabdq_u8(a_lo_u8x16, b_lo_u8x16);
        uint8x16_t abs_diff_hi_u8x16 = vabdq_u8(a_hi_u8x16, b_hi_u8x16);
        sum_sq_u32x4 = vdotq_u32(sum_sq_u32x4, abs_diff_lo_u8x16, abs_diff_lo_u8x16);
        sum_sq_u32x4 = vdotq_u32(sum_sq_u32x4, abs_diff_hi_u8x16, abs_diff_hi_u8x16);
    } while (bytes_left);

    *result = vaddvq_u32(sum_sq_u32x4);
}
|
|
368
|
+
|
|
369
|
+
/**
 *  @brief Euclidean distance between two vectors of packed unsigned 4-bit integers.
 *
 *  Delegates to nk_sqeuclidean_u4_neonsdot for the integer squared distance, converts it to
 *  single precision and passes it through nk_f32_sqrt_neon.
 *
 *  @param[in] a First packed input.
 *  @param[in] b Second packed input.
 *  @param[in] n Number of 4-bit elements in each input.
 *  @param[out] result Receives the distance.
 */
NK_PUBLIC void nk_euclidean_u4_neonsdot(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_u32_t squared_distance;
    nk_sqeuclidean_u4_neonsdot(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_neon((nk_f32_t)squared_distance);
}
|
|
374
|
+
|
|
375
|
+
/**
 *  @brief Angular distance between two vectors of packed unsigned 4-bit integers.
 *
 *  Accumulates the dot product of @p a and @p b together with the squared norms of each
 *  input using UDOT, then hands the three sums (converted to single precision) to
 *  nk_angular_normalize_f32_neon_, whose return value is stored in @p result.
 *
 *  @param[in] a First packed input.
 *  @param[in] b Second packed input.
 *  @param[in] n Number of 4-bit elements in each input; rounded up to an even count.
 *  @param[out] result Receives the value returned by the normalization helper.
 */
NK_PUBLIC void nk_angular_u4_neonsdot(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t bytes_left = nk_size_round_up_to_multiple_(n, 2) / 2;
    uint8x16_t const low_nibble_mask_u8x16 = vdupq_n_u8(0x0F);
    uint32x4_t dot_u32x4 = vdupq_n_u32(0);
    uint32x4_t a_norm_sq_u32x4 = vdupq_n_u32(0);
    uint32x4_t b_norm_sq_u32x4 = vdupq_n_u32(0);

    // The body runs at least once, even for an empty input, mirroring a bottom-tested loop.
    do {
        uint8x16_t a_packed_u8x16, b_packed_u8x16;
        if (bytes_left >= 16) {
            a_packed_u8x16 = vld1q_u8((nk_u8_t const *)a);
            b_packed_u8x16 = vld1q_u8((nk_u8_t const *)b);
            a += 16, b += 16, bytes_left -= 16;
        }
        else {
            // Tail: the partial-load helper takes the count in nibbles.
            // NOTE(review): presumably zero-fills the lanes it does not read - confirm.
            nk_b128_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b4x32_serial_(a, &a_tail_vec, bytes_left * 2);
            nk_partial_load_b4x32_serial_(b, &b_tail_vec, bytes_left * 2);
            a_packed_u8x16 = a_tail_vec.u8x16;
            b_packed_u8x16 = b_tail_vec.u8x16;
            bytes_left = 0;
        }

        // Unsigned nibbles: mask for the low half, logical shift for the high half.
        uint8x16_t a_lo_u8x16 = vandq_u8(a_packed_u8x16, low_nibble_mask_u8x16);
        uint8x16_t a_hi_u8x16 = vshrq_n_u8(a_packed_u8x16, 4);
        uint8x16_t b_lo_u8x16 = vandq_u8(b_packed_u8x16, low_nibble_mask_u8x16);
        uint8x16_t b_hi_u8x16 = vshrq_n_u8(b_packed_u8x16, 4);

        dot_u32x4 = vdotq_u32(dot_u32x4, a_lo_u8x16, b_lo_u8x16);
        dot_u32x4 = vdotq_u32(dot_u32x4, a_hi_u8x16, b_hi_u8x16);
        a_norm_sq_u32x4 = vdotq_u32(a_norm_sq_u32x4, a_lo_u8x16, a_lo_u8x16);
        a_norm_sq_u32x4 = vdotq_u32(a_norm_sq_u32x4, a_hi_u8x16, a_hi_u8x16);
        b_norm_sq_u32x4 = vdotq_u32(b_norm_sq_u32x4, b_lo_u8x16, b_lo_u8x16);
        b_norm_sq_u32x4 = vdotq_u32(b_norm_sq_u32x4, b_hi_u8x16, b_hi_u8x16);
    } while (bytes_left);

    // Collapse each accumulator to a scalar before normalizing.
    nk_f32_t dot_f32 = (nk_f32_t)vaddvq_u32(dot_u32x4);
    nk_f32_t a_norm_sq_f32 = (nk_f32_t)vaddvq_u32(a_norm_sq_u32x4);
    nk_f32_t b_norm_sq_f32 = (nk_f32_t)vaddvq_u32(b_norm_sq_u32x4);
    *result = nk_angular_normalize_f32_neon_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
248
416
|
|
|
249
417
|
#if defined(__clang__)
|