numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,15 +8,15 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section spatial_svehalf_instructions ARM SVE+FP16 Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* svld1_f16
|
|
13
|
-
* svsub_f16_x
|
|
14
|
-
* svmla_f16_x
|
|
15
|
-
* svaddv_f16
|
|
16
|
-
* svdupq_n_f16
|
|
17
|
-
* svwhilelt_b16
|
|
18
|
-
* svptrue_b16
|
|
19
|
-
* svcnth
|
|
11
|
+
* Intrinsic Instruction V1
|
|
12
|
+
* svld1_f16 LD1H (Z.H, P/Z, [Xn]) 4-6cy @ 2p
|
|
13
|
+
* svsub_f16_x FSUB (Z.H, P/M, Z.H, Z.H) 3cy @ 2p
|
|
14
|
+
* svmla_f16_x FMLA (Z.H, P/M, Z.H, Z.H) 4cy @ 2p
|
|
15
|
+
* svaddv_f16 FADDV (H, P, Z.H) 6cy @ 1p
|
|
16
|
+
* svdupq_n_f16 DUP (Z.H, #imm) 1cy @ 2p
|
|
17
|
+
* svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy @ 1p
|
|
18
|
+
* svptrue_b16 PTRUE (P.H, pattern) 1cy @ 2p
|
|
19
|
+
* svcnth CNTH (Xd) 1cy @ 2p
|
|
20
20
|
*
|
|
21
21
|
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
22
22
|
* and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
|
|
@@ -52,14 +52,27 @@ NK_PUBLIC void nk_sqeuclidean_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const
|
|
|
52
52
|
nk_f16_for_arm_simd_t const *a = (nk_f16_for_arm_simd_t const *)(a_enum);
|
|
53
53
|
nk_f16_for_arm_simd_t const *b = (nk_f16_for_arm_simd_t const *)(b_enum);
|
|
54
54
|
do {
|
|
55
|
-
svbool_t
|
|
56
|
-
svfloat16_t a_f16x = svld1_f16(
|
|
57
|
-
svfloat16_t b_f16x = svld1_f16(
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
55
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(i, n);
|
|
56
|
+
svfloat16_t a_f16x = svld1_f16(predicate_b16x, a + i);
|
|
57
|
+
svfloat16_t b_f16x = svld1_f16(predicate_b16x, b + i);
|
|
58
|
+
nk_size_t remaining = n - i < svcnth() ? n - i : svcnth();
|
|
59
|
+
|
|
60
|
+
// SVE `svcvt_f32_f16_x` widens only even-indexed f16 elements (0, 2, 4, ...),
|
|
61
|
+
// so we need two passes: one on the original vector (even elements) and one on
|
|
62
|
+
// a vector shifted by one position via `svext` (odd elements become even).
|
|
63
|
+
svbool_t pred_even_b32x = svwhilelt_b32_u64(0u, (remaining + 1) / 2);
|
|
64
|
+
svfloat32_t a_even_f32x = svcvt_f32_f16_x(pred_even_b32x, a_f16x);
|
|
65
|
+
svfloat32_t b_even_f32x = svcvt_f32_f16_x(pred_even_b32x, b_f16x);
|
|
66
|
+
svfloat32_t diff_even_f32x = svsub_f32_x(pred_even_b32x, a_even_f32x, b_even_f32x);
|
|
67
|
+
d2_f32x = svmla_f32_m(pred_even_b32x, d2_f32x, diff_even_f32x, diff_even_f32x);
|
|
68
|
+
|
|
69
|
+
svbool_t pred_odd_b32x = svwhilelt_b32_u64(0u, remaining / 2);
|
|
70
|
+
svfloat32_t a_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(a_f16x, a_f16x, 1));
|
|
71
|
+
svfloat32_t b_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(b_f16x, b_f16x, 1));
|
|
72
|
+
svfloat32_t diff_odd_f32x = svsub_f32_x(pred_odd_b32x, a_odd_f32x, b_odd_f32x);
|
|
73
|
+
d2_f32x = svmla_f32_m(pred_odd_b32x, d2_f32x, diff_odd_f32x, diff_odd_f32x);
|
|
74
|
+
|
|
75
|
+
i += svcnth();
|
|
63
76
|
} while (i < n);
|
|
64
77
|
*result = svaddv_f32(svptrue_b32(), d2_f32x);
|
|
65
78
|
}
|
|
@@ -77,15 +90,28 @@ NK_PUBLIC void nk_angular_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const *b_
|
|
|
77
90
|
nk_f16_for_arm_simd_t const *a = (nk_f16_for_arm_simd_t const *)(a_enum);
|
|
78
91
|
nk_f16_for_arm_simd_t const *b = (nk_f16_for_arm_simd_t const *)(b_enum);
|
|
79
92
|
do {
|
|
80
|
-
svbool_t
|
|
81
|
-
svfloat16_t a_f16x = svld1_f16(
|
|
82
|
-
svfloat16_t b_f16x = svld1_f16(
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
93
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(i, n);
|
|
94
|
+
svfloat16_t a_f16x = svld1_f16(predicate_b16x, a + i);
|
|
95
|
+
svfloat16_t b_f16x = svld1_f16(predicate_b16x, b + i);
|
|
96
|
+
nk_size_t remaining = n - i < svcnth() ? n - i : svcnth();
|
|
97
|
+
|
|
98
|
+
// Even-indexed f16 elements (0, 2, 4, ...)
|
|
99
|
+
svbool_t pred_even_b32x = svwhilelt_b32_u64(0u, (remaining + 1) / 2);
|
|
100
|
+
svfloat32_t a_even_f32x = svcvt_f32_f16_x(pred_even_b32x, a_f16x);
|
|
101
|
+
svfloat32_t b_even_f32x = svcvt_f32_f16_x(pred_even_b32x, b_f16x);
|
|
102
|
+
ab_f32x = svmla_f32_m(pred_even_b32x, ab_f32x, a_even_f32x, b_even_f32x);
|
|
103
|
+
a2_f32x = svmla_f32_m(pred_even_b32x, a2_f32x, a_even_f32x, a_even_f32x);
|
|
104
|
+
b2_f32x = svmla_f32_m(pred_even_b32x, b2_f32x, b_even_f32x, b_even_f32x);
|
|
105
|
+
|
|
106
|
+
// Odd-indexed f16 elements (1, 3, 5, ...) via svext shift-by-1
|
|
107
|
+
svbool_t pred_odd_b32x = svwhilelt_b32_u64(0u, remaining / 2);
|
|
108
|
+
svfloat32_t a_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(a_f16x, a_f16x, 1));
|
|
109
|
+
svfloat32_t b_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(b_f16x, b_f16x, 1));
|
|
110
|
+
ab_f32x = svmla_f32_m(pred_odd_b32x, ab_f32x, a_odd_f32x, b_odd_f32x);
|
|
111
|
+
a2_f32x = svmla_f32_m(pred_odd_b32x, a2_f32x, a_odd_f32x, a_odd_f32x);
|
|
112
|
+
b2_f32x = svmla_f32_m(pred_odd_b32x, b2_f32x, b_odd_f32x, b_odd_f32x);
|
|
113
|
+
|
|
114
|
+
i += svcnth();
|
|
89
115
|
} while (i < n);
|
|
90
116
|
|
|
91
117
|
nk_f32_t ab_f32 = svaddv_f32(svptrue_b32(), ab_f32x);
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for SVE SDOT.
|
|
3
|
+
* @file include/numkong/spatial/svesdot.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date April 3, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* @section spatial_svesdot_instructions ARM SVE+DotProd Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction V1
|
|
12
|
+
* svld1_s8 LD1B (Z.B, P/Z, [Xn]) 4-6cy @ 2p
|
|
13
|
+
* svld1_u8 LD1B (Z.B, P/Z, [Xn]) 4-6cy @ 2p
|
|
14
|
+
* svdot_s32 SDOT (Z.S, Z.B, Z.B) 3cy @ 2p
|
|
15
|
+
* svdot_u32 UDOT (Z.S, Z.B, Z.B) 3cy @ 2p
|
|
16
|
+
* svabd_s8_x SABD (Z.B, P/M, Z.B) 3cy @ 2p
|
|
17
|
+
* svabd_u8_x UABD (Z.B, P/M, Z.B) 3cy @ 2p
|
|
18
|
+
* svaddv_s32 SADDV (D, P, Z.S) 6cy @ 1p
|
|
19
|
+
* svaddv_u32 UADDV (D, P, Z.S) 6cy @ 1p
|
|
20
|
+
* svwhilelt_b8 WHILELT (P.B, Xn, Xm) 2cy @ 1p
|
|
21
|
+
* svcntb CNTB (Xd) 1cy @ 2p
|
|
22
|
+
*
|
|
23
|
+
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
24
|
+
* and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
|
|
25
|
+
* process more elements per iteration with identical latencies.
|
|
26
|
+
*
|
|
27
|
+
* For L2 distance, SABD/UABD computes |a-b| per byte, then UDOT squares and accumulates.
|
|
28
|
+
* Angular distance uses SDOT/UDOT directly for dot product and norm computations.
|
|
29
|
+
*/
|
|
30
|
+
#ifndef NK_SPATIAL_SVESDOT_H
|
|
31
|
+
#define NK_SPATIAL_SVESDOT_H
|
|
32
|
+
|
|
33
|
+
#if NK_TARGET_ARM_
|
|
34
|
+
#if NK_TARGET_SVESDOT
|
|
35
|
+
|
|
36
|
+
#include "numkong/types.h"
|
|
37
|
+
#include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
|
|
38
|
+
|
|
39
|
+
#if defined(__cplusplus)
|
|
40
|
+
extern "C" {
|
|
41
|
+
#endif
|
|
42
|
+
|
|
43
|
+
#if defined(__clang__)
|
|
44
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+dotprod"))), apply_to = function)
|
|
45
|
+
#elif defined(__GNUC__)
|
|
46
|
+
#pragma GCC push_options
|
|
47
|
+
#pragma GCC target("arch=armv8.2-a+sve+dotprod")
|
|
48
|
+
#endif
|
|
49
|
+
|
|
50
|
+
NK_PUBLIC void nk_sqeuclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
51
|
+
nk_size_t i = 0;
|
|
52
|
+
svuint32_t distance_sq_u32x = svdup_u32(0);
|
|
53
|
+
do {
|
|
54
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(i, n);
|
|
55
|
+
svint8_t a_i8x = svld1_s8(predicate_b8x, a + i);
|
|
56
|
+
svint8_t b_i8x = svld1_s8(predicate_b8x, b + i);
|
|
57
|
+
svuint8_t diff_u8x = svreinterpret_u8_s8(svabd_s8_x(predicate_b8x, a_i8x, b_i8x));
|
|
58
|
+
distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
|
|
59
|
+
i += svcntb();
|
|
60
|
+
} while (i < n);
|
|
61
|
+
*result = (nk_u32_t)svaddv_u32(svptrue_b32(), distance_sq_u32x);
|
|
62
|
+
}
|
|
63
|
+
NK_PUBLIC void nk_euclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
64
|
+
nk_u32_t distance_sq_u32;
|
|
65
|
+
nk_sqeuclidean_i8_svesdot(a, b, n, &distance_sq_u32);
|
|
66
|
+
*result = nk_f32_sqrt_neon((nk_f32_t)distance_sq_u32);
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
NK_PUBLIC void nk_angular_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
70
|
+
nk_size_t i = 0;
|
|
71
|
+
svint32_t ab_i32x = svdup_s32(0);
|
|
72
|
+
svint32_t a2_i32x = svdup_s32(0);
|
|
73
|
+
svint32_t b2_i32x = svdup_s32(0);
|
|
74
|
+
do {
|
|
75
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(i, n);
|
|
76
|
+
svint8_t a_i8x = svld1_s8(predicate_b8x, a + i);
|
|
77
|
+
svint8_t b_i8x = svld1_s8(predicate_b8x, b + i);
|
|
78
|
+
ab_i32x = svdot_s32(ab_i32x, a_i8x, b_i8x);
|
|
79
|
+
a2_i32x = svdot_s32(a2_i32x, a_i8x, a_i8x);
|
|
80
|
+
b2_i32x = svdot_s32(b2_i32x, b_i8x, b_i8x);
|
|
81
|
+
i += svcntb();
|
|
82
|
+
} while (i < n);
|
|
83
|
+
|
|
84
|
+
nk_i32_t ab = (nk_i32_t)svaddv_s32(svptrue_b32(), ab_i32x);
|
|
85
|
+
nk_i32_t a2 = (nk_i32_t)svaddv_s32(svptrue_b32(), a2_i32x);
|
|
86
|
+
nk_i32_t b2 = (nk_i32_t)svaddv_s32(svptrue_b32(), b2_i32x);
|
|
87
|
+
*result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
NK_PUBLIC void nk_sqeuclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
91
|
+
nk_size_t i = 0;
|
|
92
|
+
svuint32_t distance_sq_u32x = svdup_u32(0);
|
|
93
|
+
do {
|
|
94
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(i, n);
|
|
95
|
+
svuint8_t a_u8x = svld1_u8(predicate_b8x, a + i);
|
|
96
|
+
svuint8_t b_u8x = svld1_u8(predicate_b8x, b + i);
|
|
97
|
+
svuint8_t diff_u8x = svabd_u8_x(predicate_b8x, a_u8x, b_u8x);
|
|
98
|
+
distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
|
|
99
|
+
i += svcntb();
|
|
100
|
+
} while (i < n);
|
|
101
|
+
*result = (nk_u32_t)svaddv_u32(svptrue_b32(), distance_sq_u32x);
|
|
102
|
+
}
|
|
103
|
+
NK_PUBLIC void nk_euclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
104
|
+
nk_u32_t distance_sq_u32;
|
|
105
|
+
nk_sqeuclidean_u8_svesdot(a, b, n, &distance_sq_u32);
|
|
106
|
+
*result = nk_f32_sqrt_neon((nk_f32_t)distance_sq_u32);
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
NK_PUBLIC void nk_angular_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
110
|
+
nk_size_t i = 0;
|
|
111
|
+
svuint32_t ab_u32x = svdup_u32(0);
|
|
112
|
+
svuint32_t a2_u32x = svdup_u32(0);
|
|
113
|
+
svuint32_t b2_u32x = svdup_u32(0);
|
|
114
|
+
do {
|
|
115
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(i, n);
|
|
116
|
+
svuint8_t a_u8x = svld1_u8(predicate_b8x, a + i);
|
|
117
|
+
svuint8_t b_u8x = svld1_u8(predicate_b8x, b + i);
|
|
118
|
+
ab_u32x = svdot_u32(ab_u32x, a_u8x, b_u8x);
|
|
119
|
+
a2_u32x = svdot_u32(a2_u32x, a_u8x, a_u8x);
|
|
120
|
+
b2_u32x = svdot_u32(b2_u32x, b_u8x, b_u8x);
|
|
121
|
+
i += svcntb();
|
|
122
|
+
} while (i < n);
|
|
123
|
+
|
|
124
|
+
nk_u32_t ab = (nk_u32_t)svaddv_u32(svptrue_b32(), ab_u32x);
|
|
125
|
+
nk_u32_t a2 = (nk_u32_t)svaddv_u32(svptrue_b32(), a2_u32x);
|
|
126
|
+
nk_u32_t b2 = (nk_u32_t)svaddv_u32(svptrue_b32(), b2_u32x);
|
|
127
|
+
*result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
#if defined(__clang__)
|
|
131
|
+
#pragma clang attribute pop
|
|
132
|
+
#elif defined(__GNUC__)
|
|
133
|
+
#pragma GCC pop_options
|
|
134
|
+
#endif
|
|
135
|
+
|
|
136
|
+
#if defined(__cplusplus)
|
|
137
|
+
} // extern "C"
|
|
138
|
+
#endif
|
|
139
|
+
|
|
140
|
+
#endif // NK_TARGET_SVESDOT
|
|
141
|
+
#endif // NK_TARGET_ARM_
|
|
142
|
+
#endif // NK_SPATIAL_SVESDOT_H
|
|
@@ -64,7 +64,7 @@ NK_INTERNAL nk_f64_t nk_angular_normalize_f64_v128relaxed_(nk_f64_t ab, nk_f64_t
|
|
|
64
64
|
return result > 0.0 ? result : 0.0;
|
|
65
65
|
}
|
|
66
66
|
|
|
67
|
-
#pragma region
|
|
67
|
+
#pragma region F32 and F64 Floats
|
|
68
68
|
|
|
69
69
|
NK_PUBLIC void nk_sqeuclidean_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
70
70
|
v128_t sum_f64x2 = wasm_f64x2_splat(0.0);
|
|
@@ -83,8 +83,8 @@ nk_sqeuclidean_f32_v128relaxed_cycle:
|
|
|
83
83
|
nk_load_b64_serial_(b_scalars, &b_f32_vec);
|
|
84
84
|
a_scalars += 2, b_scalars += 2, count_scalars -= 2;
|
|
85
85
|
}
|
|
86
|
-
v128_t a_f32x2 =
|
|
87
|
-
v128_t b_f32x2 =
|
|
86
|
+
v128_t a_f32x2 = wasm_i64x2_splat(a_f32_vec.u64);
|
|
87
|
+
v128_t b_f32x2 = wasm_i64x2_splat(b_f32_vec.u64);
|
|
88
88
|
v128_t a_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
|
|
89
89
|
v128_t b_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
|
|
90
90
|
v128_t diff_f64x2 = wasm_f64x2_sub(a_f64x2, b_f64x2);
|
|
@@ -152,8 +152,8 @@ nk_angular_f32_v128relaxed_cycle:
|
|
|
152
152
|
}
|
|
153
153
|
|
|
154
154
|
// Upcast F32x2 → F64x2 for high-precision accumulation
|
|
155
|
-
v128_t a_f32x2 =
|
|
156
|
-
v128_t b_f32x2 =
|
|
155
|
+
v128_t a_f32x2 = wasm_i64x2_splat(a_f32_vec.u64);
|
|
156
|
+
v128_t b_f32x2 = wasm_i64x2_splat(b_f32_vec.u64);
|
|
157
157
|
v128_t a_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
|
|
158
158
|
v128_t b_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
|
|
159
159
|
|
|
@@ -203,8 +203,8 @@ nk_angular_f64_v128relaxed_cycle:
|
|
|
203
203
|
*result = nk_angular_normalize_f64_v128relaxed_(ab, a2, b2);
|
|
204
204
|
}
|
|
205
205
|
|
|
206
|
-
#pragma endregion
|
|
207
|
-
#pragma region
|
|
206
|
+
#pragma endregion F32 and F64 Floats
|
|
207
|
+
#pragma region F16 and BF16 Floats
|
|
208
208
|
|
|
209
209
|
NK_PUBLIC void nk_sqeuclidean_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
210
210
|
v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
|
|
@@ -286,31 +286,30 @@ nk_angular_f16_v128relaxed_cycle:
|
|
|
286
286
|
|
|
287
287
|
NK_PUBLIC void nk_sqeuclidean_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
288
288
|
v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
|
|
289
|
+
v128_t mask_high_u32x4 = wasm_i32x4_splat((int)0xFFFF0000);
|
|
289
290
|
nk_bf16_t const *a_scalars = a, *b_scalars = b;
|
|
290
291
|
nk_size_t count_scalars = n;
|
|
291
|
-
|
|
292
|
+
nk_b128_vec_t a_bf16_vec, b_bf16_vec;
|
|
292
293
|
|
|
293
294
|
nk_sqeuclidean_bf16_v128relaxed_cycle:
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
nk_partial_load_b16x4_serial_(b_scalars, &b_bf16_vec, count_scalars);
|
|
295
|
+
if (count_scalars < 8) {
|
|
296
|
+
nk_partial_load_b16x8_serial_(a_scalars, &a_bf16_vec, count_scalars);
|
|
297
|
+
nk_partial_load_b16x8_serial_(b_scalars, &b_bf16_vec, count_scalars);
|
|
298
298
|
count_scalars = 0;
|
|
299
299
|
}
|
|
300
300
|
else {
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
a_scalars +=
|
|
301
|
+
nk_load_b128_v128relaxed_(a_scalars, &a_bf16_vec);
|
|
302
|
+
nk_load_b128_v128relaxed_(b_scalars, &b_bf16_vec);
|
|
303
|
+
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
304
304
|
}
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
v128_t
|
|
312
|
-
sum_f32x4 = wasm_f32x4_relaxed_madd(
|
|
313
|
-
|
|
305
|
+
v128_t a_even_f32x4 = wasm_i32x4_shl(a_bf16_vec.v128, 16);
|
|
306
|
+
v128_t b_even_f32x4 = wasm_i32x4_shl(b_bf16_vec.v128, 16);
|
|
307
|
+
v128_t diff_even_f32x4 = wasm_f32x4_sub(a_even_f32x4, b_even_f32x4);
|
|
308
|
+
sum_f32x4 = wasm_f32x4_relaxed_madd(diff_even_f32x4, diff_even_f32x4, sum_f32x4);
|
|
309
|
+
v128_t a_odd_f32x4 = wasm_v128_and(a_bf16_vec.v128, mask_high_u32x4);
|
|
310
|
+
v128_t b_odd_f32x4 = wasm_v128_and(b_bf16_vec.v128, mask_high_u32x4);
|
|
311
|
+
v128_t diff_odd_f32x4 = wasm_f32x4_sub(a_odd_f32x4, b_odd_f32x4);
|
|
312
|
+
sum_f32x4 = wasm_f32x4_relaxed_madd(diff_odd_f32x4, diff_odd_f32x4, sum_f32x4);
|
|
314
313
|
if (count_scalars) goto nk_sqeuclidean_bf16_v128relaxed_cycle;
|
|
315
314
|
|
|
316
315
|
*result = nk_reduce_add_f32x4_v128relaxed_(sum_f32x4);
|
|
@@ -326,44 +325,297 @@ NK_PUBLIC void nk_angular_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *
|
|
|
326
325
|
v128_t ab_f32x4 = wasm_f32x4_splat(0.0f);
|
|
327
326
|
v128_t a2_f32x4 = wasm_f32x4_splat(0.0f);
|
|
328
327
|
v128_t b2_f32x4 = wasm_f32x4_splat(0.0f);
|
|
328
|
+
v128_t mask_high_u32x4 = wasm_i32x4_splat((int)0xFFFF0000);
|
|
329
329
|
nk_bf16_t const *a_scalars = a, *b_scalars = b;
|
|
330
330
|
nk_size_t count_scalars = n;
|
|
331
|
-
|
|
331
|
+
nk_b128_vec_t a_bf16_vec, b_bf16_vec;
|
|
332
332
|
|
|
333
333
|
nk_angular_bf16_v128relaxed_cycle:
|
|
334
|
+
if (count_scalars < 8) {
|
|
335
|
+
nk_partial_load_b16x8_serial_(a_scalars, &a_bf16_vec, count_scalars);
|
|
336
|
+
nk_partial_load_b16x8_serial_(b_scalars, &b_bf16_vec, count_scalars);
|
|
337
|
+
count_scalars = 0;
|
|
338
|
+
}
|
|
339
|
+
else {
|
|
340
|
+
nk_load_b128_v128relaxed_(a_scalars, &a_bf16_vec);
|
|
341
|
+
nk_load_b128_v128relaxed_(b_scalars, &b_bf16_vec);
|
|
342
|
+
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
343
|
+
}
|
|
344
|
+
v128_t a_even_f32x4 = wasm_i32x4_shl(a_bf16_vec.v128, 16);
|
|
345
|
+
v128_t b_even_f32x4 = wasm_i32x4_shl(b_bf16_vec.v128, 16);
|
|
346
|
+
ab_f32x4 = wasm_f32x4_relaxed_madd(a_even_f32x4, b_even_f32x4, ab_f32x4);
|
|
347
|
+
a2_f32x4 = wasm_f32x4_relaxed_madd(a_even_f32x4, a_even_f32x4, a2_f32x4);
|
|
348
|
+
b2_f32x4 = wasm_f32x4_relaxed_madd(b_even_f32x4, b_even_f32x4, b2_f32x4);
|
|
349
|
+
v128_t a_odd_f32x4 = wasm_v128_and(a_bf16_vec.v128, mask_high_u32x4);
|
|
350
|
+
v128_t b_odd_f32x4 = wasm_v128_and(b_bf16_vec.v128, mask_high_u32x4);
|
|
351
|
+
ab_f32x4 = wasm_f32x4_relaxed_madd(a_odd_f32x4, b_odd_f32x4, ab_f32x4);
|
|
352
|
+
a2_f32x4 = wasm_f32x4_relaxed_madd(a_odd_f32x4, a_odd_f32x4, a2_f32x4);
|
|
353
|
+
b2_f32x4 = wasm_f32x4_relaxed_madd(b_odd_f32x4, b_odd_f32x4, b2_f32x4);
|
|
354
|
+
if (count_scalars) goto nk_angular_bf16_v128relaxed_cycle;
|
|
355
|
+
|
|
356
|
+
nk_f32_t ab = nk_reduce_add_f32x4_v128relaxed_(ab_f32x4);
|
|
357
|
+
nk_f32_t a2 = nk_reduce_add_f32x4_v128relaxed_(a2_f32x4);
|
|
358
|
+
nk_f32_t b2 = nk_reduce_add_f32x4_v128relaxed_(b2_f32x4);
|
|
359
|
+
*result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)ab, (nk_f64_t)a2, (nk_f64_t)b2);
|
|
360
|
+
}
|
|
361
|
+
|
|
362
|
+
#pragma endregion F16 and BF16 Floats
|
|
363
|
+
#pragma region FP8 Floats
|
|
364
|
+
|
|
365
|
+
/**
 *  @brief Squared Euclidean distance between two e4m3 vectors, accumulated in f32.
 *
 *  Four scalars are consumed per pass: both inputs go through `nk_e4m3x4_to_f32x4_v128relaxed_`,
 *  are subtracted lane-wise, and the squared differences are folded into an f32x4 accumulator
 *  with a relaxed multiply-add.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the horizontal sum of the f32x4 accumulator.
 */
NK_PUBLIC void nk_sqeuclidean_e4m3_v128relaxed(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_e4m3_t const *a_cursor = a, *b_cursor = b;
    nk_size_t remaining = n;
    nk_b32_vec_t a_packed, b_packed;
    v128_t accumulator_f32x4 = wasm_f32x4_splat(0.0f);

    // The body runs at least once, so `n == 0` still takes the partial-load path.
    do {
        if (remaining >= 4) {
            nk_load_b32_serial_(a_cursor, &a_packed);
            nk_load_b32_serial_(b_cursor, &b_packed);
            a_cursor += 4;
            b_cursor += 4;
            remaining -= 4;
        }
        else {
            // Tail of fewer than four scalars.
            // NOTE(review): the helper presumably pads the missing lanes with zeros - confirm.
            a_packed = nk_partial_load_b8x4_serial_(a_cursor, remaining);
            b_packed = nk_partial_load_b8x4_serial_(b_cursor, remaining);
            remaining = 0;
        }
        nk_b128_vec_t a_wide = nk_e4m3x4_to_f32x4_v128relaxed_(a_packed);
        nk_b128_vec_t b_wide = nk_e4m3x4_to_f32x4_v128relaxed_(b_packed);
        v128_t delta_f32x4 = wasm_f32x4_sub(a_wide.v128, b_wide.v128);
        accumulator_f32x4 = wasm_f32x4_relaxed_madd(delta_f32x4, delta_f32x4, accumulator_f32x4);
    } while (remaining);

    *result = nk_reduce_add_f32x4_v128relaxed_(accumulator_f32x4);
}
|
|
348
390
|
|
|
349
|
-
|
|
391
|
+
/**
 *  @brief Euclidean distance between two e4m3 vectors.
 *
 *  Delegates to `nk_sqeuclidean_e4m3_v128relaxed` and applies `nk_f32_sqrt_v128relaxed` to its output.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the square root of the squared distance.
 */
NK_PUBLIC void nk_euclidean_e4m3_v128relaxed(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e4m3_v128relaxed(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_v128relaxed(squared_distance);
}
|
|
395
|
+
|
|
396
|
+
/**
 *  @brief Angular distance between two e4m3 vectors.
 *
 *  Accumulates three f32x4 sums over four scalars per pass: the products a*b, a*a and b*b.
 *  The reduced sums are widened to f64 and handed to `nk_angular_normalize_f64_v128relaxed_`,
 *  which is defined elsewhere and produces the final value.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the normalized value narrowed back to f32.
 */
NK_PUBLIC void nk_angular_e4m3_v128relaxed(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_e4m3_t const *a_cursor = a, *b_cursor = b;
    nk_size_t remaining = n;
    nk_b32_vec_t a_packed, b_packed;
    v128_t dot_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t a_sumsq_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t b_sumsq_f32x4 = wasm_f32x4_splat(0.0f);

    // The body runs at least once, so `n == 0` still takes the partial-load path.
    do {
        if (remaining >= 4) {
            nk_load_b32_serial_(a_cursor, &a_packed);
            nk_load_b32_serial_(b_cursor, &b_packed);
            a_cursor += 4;
            b_cursor += 4;
            remaining -= 4;
        }
        else {
            // Tail of fewer than four scalars.
            // NOTE(review): the helper presumably pads the missing lanes with zeros - confirm.
            a_packed = nk_partial_load_b8x4_serial_(a_cursor, remaining);
            b_packed = nk_partial_load_b8x4_serial_(b_cursor, remaining);
            remaining = 0;
        }
        nk_b128_vec_t a_wide = nk_e4m3x4_to_f32x4_v128relaxed_(a_packed);
        nk_b128_vec_t b_wide = nk_e4m3x4_to_f32x4_v128relaxed_(b_packed);
        dot_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, b_wide.v128, dot_f32x4);
        a_sumsq_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, a_wide.v128, a_sumsq_f32x4);
        b_sumsq_f32x4 = wasm_f32x4_relaxed_madd(b_wide.v128, b_wide.v128, b_sumsq_f32x4);
    } while (remaining);

    nk_f32_t dot = nk_reduce_add_f32x4_v128relaxed_(dot_f32x4);
    nk_f32_t a_sumsq = nk_reduce_add_f32x4_v128relaxed_(a_sumsq_f32x4);
    nk_f32_t b_sumsq = nk_reduce_add_f32x4_v128relaxed_(b_sumsq_f32x4);
    *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)dot, (nk_f64_t)a_sumsq, (nk_f64_t)b_sumsq);
}
|
|
427
|
+
|
|
428
|
+
/**
 *  @brief Squared Euclidean distance between two e5m2 vectors, accumulated in f32.
 *
 *  Four scalars are consumed per pass: both inputs go through `nk_e5m2x4_to_f32x4_v128relaxed_`,
 *  are subtracted lane-wise, and the squared differences are folded into an f32x4 accumulator
 *  with a relaxed multiply-add.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the horizontal sum of the f32x4 accumulator.
 */
NK_PUBLIC void nk_sqeuclidean_e5m2_v128relaxed(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_e5m2_t const *a_cursor = a, *b_cursor = b;
    nk_size_t remaining = n;
    nk_b32_vec_t a_packed, b_packed;
    v128_t accumulator_f32x4 = wasm_f32x4_splat(0.0f);

    // The body runs at least once, so `n == 0` still takes the partial-load path.
    do {
        if (remaining >= 4) {
            nk_load_b32_serial_(a_cursor, &a_packed);
            nk_load_b32_serial_(b_cursor, &b_packed);
            a_cursor += 4;
            b_cursor += 4;
            remaining -= 4;
        }
        else {
            // Tail of fewer than four scalars.
            // NOTE(review): the helper presumably pads the missing lanes with zeros - confirm.
            a_packed = nk_partial_load_b8x4_serial_(a_cursor, remaining);
            b_packed = nk_partial_load_b8x4_serial_(b_cursor, remaining);
            remaining = 0;
        }
        nk_b128_vec_t a_wide = nk_e5m2x4_to_f32x4_v128relaxed_(a_packed);
        nk_b128_vec_t b_wide = nk_e5m2x4_to_f32x4_v128relaxed_(b_packed);
        v128_t delta_f32x4 = wasm_f32x4_sub(a_wide.v128, b_wide.v128);
        accumulator_f32x4 = wasm_f32x4_relaxed_madd(delta_f32x4, delta_f32x4, accumulator_f32x4);
    } while (remaining);

    *result = nk_reduce_add_f32x4_v128relaxed_(accumulator_f32x4);
}
|
|
453
|
+
|
|
454
|
+
/**
 *  @brief Euclidean distance between two e5m2 vectors.
 *
 *  Delegates to `nk_sqeuclidean_e5m2_v128relaxed` and applies `nk_f32_sqrt_v128relaxed` to its output.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the square root of the squared distance.
 */
NK_PUBLIC void nk_euclidean_e5m2_v128relaxed(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e5m2_v128relaxed(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_v128relaxed(squared_distance);
}
|
|
458
|
+
|
|
459
|
+
/**
 *  @brief Angular distance between two e5m2 vectors.
 *
 *  Accumulates three f32x4 sums over four scalars per pass: the products a*b, a*a and b*b.
 *  The reduced sums are widened to f64 and handed to `nk_angular_normalize_f64_v128relaxed_`,
 *  which is defined elsewhere and produces the final value.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the normalized value narrowed back to f32.
 */
NK_PUBLIC void nk_angular_e5m2_v128relaxed(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_e5m2_t const *a_cursor = a, *b_cursor = b;
    nk_size_t remaining = n;
    nk_b32_vec_t a_packed, b_packed;
    v128_t dot_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t a_sumsq_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t b_sumsq_f32x4 = wasm_f32x4_splat(0.0f);

    // The body runs at least once, so `n == 0` still takes the partial-load path.
    do {
        if (remaining >= 4) {
            nk_load_b32_serial_(a_cursor, &a_packed);
            nk_load_b32_serial_(b_cursor, &b_packed);
            a_cursor += 4;
            b_cursor += 4;
            remaining -= 4;
        }
        else {
            // Tail of fewer than four scalars.
            // NOTE(review): the helper presumably pads the missing lanes with zeros - confirm.
            a_packed = nk_partial_load_b8x4_serial_(a_cursor, remaining);
            b_packed = nk_partial_load_b8x4_serial_(b_cursor, remaining);
            remaining = 0;
        }
        nk_b128_vec_t a_wide = nk_e5m2x4_to_f32x4_v128relaxed_(a_packed);
        nk_b128_vec_t b_wide = nk_e5m2x4_to_f32x4_v128relaxed_(b_packed);
        dot_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, b_wide.v128, dot_f32x4);
        a_sumsq_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, a_wide.v128, a_sumsq_f32x4);
        b_sumsq_f32x4 = wasm_f32x4_relaxed_madd(b_wide.v128, b_wide.v128, b_sumsq_f32x4);
    } while (remaining);

    nk_f32_t dot = nk_reduce_add_f32x4_v128relaxed_(dot_f32x4);
    nk_f32_t a_sumsq = nk_reduce_add_f32x4_v128relaxed_(a_sumsq_f32x4);
    nk_f32_t b_sumsq = nk_reduce_add_f32x4_v128relaxed_(b_sumsq_f32x4);
    *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)dot, (nk_f64_t)a_sumsq, (nk_f64_t)b_sumsq);
}
|
|
360
490
|
|
|
361
|
-
|
|
491
|
+
/**
 *  @brief Squared Euclidean distance between two e2m3 vectors, accumulated in f32.
 *
 *  Four scalars are consumed per pass: both inputs go through `nk_e2m3x4_to_f32x4_v128relaxed_`,
 *  are subtracted lane-wise, and the squared differences are folded into an f32x4 accumulator
 *  with a relaxed multiply-add.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the horizontal sum of the f32x4 accumulator.
 */
NK_PUBLIC void nk_sqeuclidean_e2m3_v128relaxed(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_e2m3_t const *a_cursor = a, *b_cursor = b;
    nk_size_t remaining = n;
    nk_b32_vec_t a_packed, b_packed;
    v128_t accumulator_f32x4 = wasm_f32x4_splat(0.0f);

    // The body runs at least once, so `n == 0` still takes the partial-load path.
    do {
        if (remaining >= 4) {
            nk_load_b32_serial_(a_cursor, &a_packed);
            nk_load_b32_serial_(b_cursor, &b_packed);
            a_cursor += 4;
            b_cursor += 4;
            remaining -= 4;
        }
        else {
            // Tail of fewer than four scalars.
            // NOTE(review): the helper presumably pads the missing lanes with zeros - confirm.
            a_packed = nk_partial_load_b8x4_serial_(a_cursor, remaining);
            b_packed = nk_partial_load_b8x4_serial_(b_cursor, remaining);
            remaining = 0;
        }
        nk_b128_vec_t a_wide = nk_e2m3x4_to_f32x4_v128relaxed_(a_packed);
        nk_b128_vec_t b_wide = nk_e2m3x4_to_f32x4_v128relaxed_(b_packed);
        v128_t delta_f32x4 = wasm_f32x4_sub(a_wide.v128, b_wide.v128);
        accumulator_f32x4 = wasm_f32x4_relaxed_madd(delta_f32x4, delta_f32x4, accumulator_f32x4);
    } while (remaining);

    *result = nk_reduce_add_f32x4_v128relaxed_(accumulator_f32x4);
}
|
|
516
|
+
|
|
517
|
+
/**
 *  @brief Euclidean distance between two e2m3 vectors.
 *
 *  Delegates to `nk_sqeuclidean_e2m3_v128relaxed` and applies `nk_f32_sqrt_v128relaxed` to its output.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the square root of the squared distance.
 */
NK_PUBLIC void nk_euclidean_e2m3_v128relaxed(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e2m3_v128relaxed(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_v128relaxed(squared_distance);
}
|
|
521
|
+
|
|
522
|
+
/**
 *  @brief Angular distance between two e2m3 vectors.
 *
 *  Accumulates three f32x4 sums over four scalars per pass: the products a*b, a*a and b*b.
 *  The reduced sums are widened to f64 and handed to `nk_angular_normalize_f64_v128relaxed_`,
 *  which is defined elsewhere and produces the final value.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the normalized value narrowed back to f32.
 */
NK_PUBLIC void nk_angular_e2m3_v128relaxed(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_e2m3_t const *a_cursor = a, *b_cursor = b;
    nk_size_t remaining = n;
    nk_b32_vec_t a_packed, b_packed;
    v128_t dot_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t a_sumsq_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t b_sumsq_f32x4 = wasm_f32x4_splat(0.0f);

    // The body runs at least once, so `n == 0` still takes the partial-load path.
    do {
        if (remaining >= 4) {
            nk_load_b32_serial_(a_cursor, &a_packed);
            nk_load_b32_serial_(b_cursor, &b_packed);
            a_cursor += 4;
            b_cursor += 4;
            remaining -= 4;
        }
        else {
            // Tail of fewer than four scalars.
            // NOTE(review): the helper presumably pads the missing lanes with zeros - confirm.
            a_packed = nk_partial_load_b8x4_serial_(a_cursor, remaining);
            b_packed = nk_partial_load_b8x4_serial_(b_cursor, remaining);
            remaining = 0;
        }
        nk_b128_vec_t a_wide = nk_e2m3x4_to_f32x4_v128relaxed_(a_packed);
        nk_b128_vec_t b_wide = nk_e2m3x4_to_f32x4_v128relaxed_(b_packed);
        dot_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, b_wide.v128, dot_f32x4);
        a_sumsq_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, a_wide.v128, a_sumsq_f32x4);
        b_sumsq_f32x4 = wasm_f32x4_relaxed_madd(b_wide.v128, b_wide.v128, b_sumsq_f32x4);
    } while (remaining);

    nk_f32_t dot = nk_reduce_add_f32x4_v128relaxed_(dot_f32x4);
    nk_f32_t a_sumsq = nk_reduce_add_f32x4_v128relaxed_(a_sumsq_f32x4);
    nk_f32_t b_sumsq = nk_reduce_add_f32x4_v128relaxed_(b_sumsq_f32x4);
    *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)dot, (nk_f64_t)a_sumsq, (nk_f64_t)b_sumsq);
}
|
|
553
|
+
|
|
554
|
+
/**
 *  @brief Squared Euclidean distance between two e3m2 vectors, accumulated in f32.
 *
 *  Four scalars are consumed per pass: both inputs go through `nk_e3m2x4_to_f32x4_v128relaxed_`,
 *  are subtracted lane-wise, and the squared differences are folded into an f32x4 accumulator
 *  with a relaxed multiply-add.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the horizontal sum of the f32x4 accumulator.
 */
NK_PUBLIC void nk_sqeuclidean_e3m2_v128relaxed(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_e3m2_t const *a_cursor = a, *b_cursor = b;
    nk_size_t remaining = n;
    nk_b32_vec_t a_packed, b_packed;
    v128_t accumulator_f32x4 = wasm_f32x4_splat(0.0f);

    // The body runs at least once, so `n == 0` still takes the partial-load path.
    do {
        if (remaining >= 4) {
            nk_load_b32_serial_(a_cursor, &a_packed);
            nk_load_b32_serial_(b_cursor, &b_packed);
            a_cursor += 4;
            b_cursor += 4;
            remaining -= 4;
        }
        else {
            // Tail of fewer than four scalars.
            // NOTE(review): the helper presumably pads the missing lanes with zeros - confirm.
            a_packed = nk_partial_load_b8x4_serial_(a_cursor, remaining);
            b_packed = nk_partial_load_b8x4_serial_(b_cursor, remaining);
            remaining = 0;
        }
        nk_b128_vec_t a_wide = nk_e3m2x4_to_f32x4_v128relaxed_(a_packed);
        nk_b128_vec_t b_wide = nk_e3m2x4_to_f32x4_v128relaxed_(b_packed);
        v128_t delta_f32x4 = wasm_f32x4_sub(a_wide.v128, b_wide.v128);
        accumulator_f32x4 = wasm_f32x4_relaxed_madd(delta_f32x4, delta_f32x4, accumulator_f32x4);
    } while (remaining);

    *result = nk_reduce_add_f32x4_v128relaxed_(accumulator_f32x4);
}
|
|
579
|
+
|
|
580
|
+
/**
 *  @brief Euclidean distance between two e3m2 vectors.
 *
 *  Delegates to `nk_sqeuclidean_e3m2_v128relaxed` and applies `nk_f32_sqrt_v128relaxed` to its output.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the square root of the squared distance.
 */
NK_PUBLIC void nk_euclidean_e3m2_v128relaxed(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e3m2_v128relaxed(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_v128relaxed(squared_distance);
}
|
|
584
|
+
|
|
585
|
+
/**
 *  @brief Angular distance between two e3m2 vectors.
 *
 *  Accumulates three f32x4 sums over four scalars per pass: the products a*b, a*a and b*b.
 *  The reduced sums are widened to f64 and handed to `nk_angular_normalize_f64_v128relaxed_`,
 *  which is defined elsewhere and produces the final value.
 *
 *  @param[in] a First input vector.
 *  @param[in] b Second input vector.
 *  @param[in] n Number of scalars in each vector.
 *  @param[out] result Receives the normalized value narrowed back to f32.
 */
NK_PUBLIC void nk_angular_e3m2_v128relaxed(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_e3m2_t const *a_cursor = a, *b_cursor = b;
    nk_size_t remaining = n;
    nk_b32_vec_t a_packed, b_packed;
    v128_t dot_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t a_sumsq_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t b_sumsq_f32x4 = wasm_f32x4_splat(0.0f);

    // The body runs at least once, so `n == 0` still takes the partial-load path.
    do {
        if (remaining >= 4) {
            nk_load_b32_serial_(a_cursor, &a_packed);
            nk_load_b32_serial_(b_cursor, &b_packed);
            a_cursor += 4;
            b_cursor += 4;
            remaining -= 4;
        }
        else {
            // Tail of fewer than four scalars.
            // NOTE(review): the helper presumably pads the missing lanes with zeros - confirm.
            a_packed = nk_partial_load_b8x4_serial_(a_cursor, remaining);
            b_packed = nk_partial_load_b8x4_serial_(b_cursor, remaining);
            remaining = 0;
        }
        nk_b128_vec_t a_wide = nk_e3m2x4_to_f32x4_v128relaxed_(a_packed);
        nk_b128_vec_t b_wide = nk_e3m2x4_to_f32x4_v128relaxed_(b_packed);
        dot_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, b_wide.v128, dot_f32x4);
        a_sumsq_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, a_wide.v128, a_sumsq_f32x4);
        b_sumsq_f32x4 = wasm_f32x4_relaxed_madd(b_wide.v128, b_wide.v128, b_sumsq_f32x4);
    } while (remaining);

    nk_f32_t dot = nk_reduce_add_f32x4_v128relaxed_(dot_f32x4);
    nk_f32_t a_sumsq = nk_reduce_add_f32x4_v128relaxed_(a_sumsq_f32x4);
    nk_f32_t b_sumsq = nk_reduce_add_f32x4_v128relaxed_(b_sumsq_f32x4);
    *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)dot, (nk_f64_t)a_sumsq, (nk_f64_t)b_sumsq);
}
|
|
364
616
|
|
|
365
|
-
#pragma endregion
|
|
366
|
-
#pragma region
|
|
617
|
+
#pragma endregion FP8 Floats
|
|
618
|
+
#pragma region Spatial From Dot Helpers
|
|
367
619
|
|
|
368
620
|
/** @brief Angular from_dot: computes 1 − dot / √(query_sumsq × target_sumsq) for 4 pairs in f32. */
|
|
369
621
|
NK_INTERNAL void nk_angular_through_f32_from_dot_v128relaxed_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
@@ -437,8 +689,8 @@ NK_INTERNAL void nk_euclidean_through_u32_from_dot_v128relaxed_(nk_b128_vec_t do
|
|
|
437
689
|
results->v128 = wasm_f32x4_sqrt(dist_sq_f32x4);
|
|
438
690
|
}
|
|
439
691
|
|
|
440
|
-
#pragma endregion
|
|
441
|
-
#pragma region
|
|
692
|
+
#pragma endregion Spatial From Dot Helpers
|
|
693
|
+
#pragma region I8 and U8 Integers
|
|
442
694
|
|
|
443
695
|
NK_PUBLIC void nk_sqeuclidean_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
444
696
|
v128_t sum_u32x4 = wasm_u32x4_splat(0);
|
|
@@ -703,7 +955,7 @@ NK_PUBLIC void nk_angular_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_
|
|
|
703
955
|
*result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_(dot_ab, norm_aa, norm_bb);
|
|
704
956
|
}
|
|
705
957
|
|
|
706
|
-
#pragma endregion
|
|
958
|
+
#pragma endregion I8 and U8 Integers
|
|
707
959
|
|
|
708
960
|
#if defined(__clang__)
|
|
709
961
|
#pragma clang attribute pop
|