numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,717 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for WASM.
|
|
3
|
+
* @file include/numkong/spatial/v128relaxed.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 2, 2026
|
|
6
|
+
*
|
|
7
|
+
* Contains:
|
|
8
|
+
* - Euclidean (L2) distance
|
|
9
|
+
* - Squared Euclidean (L2SQ) distance
|
|
10
|
+
* - Angular distance (1 - cosine similarity)
|
|
11
|
+
*
|
|
12
|
+
* For dtypes:
|
|
13
|
+
* - 64-bit IEEE floating point (f64)
|
|
14
|
+
* - 32-bit IEEE floating point (f32)
|
|
15
|
+
* - 16-bit IEEE floating point (f16)
|
|
16
|
+
* - 16-bit brain floating point (bf16)
|
|
17
|
+
*
|
|
18
|
+
* Key improvements:
|
|
19
|
+
* - F32→F64 upcast for angular_f32 (matches Haswell/NEON precision strategy)
|
|
20
|
+
* - Parallel SIMD sqrt for normalization (computes both sqrts simultaneously)
|
|
21
|
+
* - Edge case handling (zero vectors, numerical stability)
|
|
22
|
+
* - Uses relaxed FMA for optimal throughput
|
|
23
|
+
*
|
|
24
|
+
* @see For pattern references:
|
|
25
|
+
* - Haswell: include/numkong/spatial/haswell.h
|
|
26
|
+
* - NEON: include/numkong/spatial/neon.h
|
|
27
|
+
*/
|
|
28
|
+
|
|
29
|
+
#ifndef NK_SPATIAL_V128RELAXED_H
|
|
30
|
+
#define NK_SPATIAL_V128RELAXED_H
|
|
31
|
+
|
|
32
|
+
#if NK_TARGET_V128RELAXED
|
|
33
|
+
|
|
34
|
+
#include "numkong/types.h"
|
|
35
|
+
#include "numkong/scalar/v128relaxed.h" // `nk_f32_sqrt_v128relaxed`
|
|
36
|
+
#include "numkong/reduce/v128relaxed.h"
|
|
37
|
+
#include "numkong/cast/serial.h"
|
|
38
|
+
#include "numkong/cast/v128relaxed.h"
|
|
39
|
+
|
|
40
|
+
#if defined(__cplusplus)
|
|
41
|
+
extern "C" {
|
|
42
|
+
#endif
|
|
43
|
+
|
|
44
|
+
#if defined(__clang__)
|
|
45
|
+
#pragma clang attribute push(__attribute__((target("relaxed-simd"))), apply_to = function)
|
|
46
|
+
#endif
|
|
47
|
+
|
|
48
|
+
NK_INTERNAL nk_f64_t nk_angular_normalize_f64_v128relaxed_(nk_f64_t ab, nk_f64_t a2, nk_f64_t b2) {
|
|
49
|
+
// Edge case: both vectors have zero magnitude
|
|
50
|
+
if (a2 == 0.0 && b2 == 0.0) return 0.0;
|
|
51
|
+
// Edge case: dot product is zero (perpendicular or one vector is zero)
|
|
52
|
+
if (ab == 0.0) return 1.0;
|
|
53
|
+
|
|
54
|
+
// Compute both square roots in parallel using SIMD (more efficient than 2 scalar sqrts)
|
|
55
|
+
v128_t squares_f64x2 = wasm_f64x2_make(a2, b2);
|
|
56
|
+
v128_t sqrts_f64x2 = wasm_f64x2_sqrt(squares_f64x2);
|
|
57
|
+
nk_f64_t a_sqrt = wasm_f64x2_extract_lane(sqrts_f64x2, 0);
|
|
58
|
+
nk_f64_t b_sqrt = wasm_f64x2_extract_lane(sqrts_f64x2, 1);
|
|
59
|
+
|
|
60
|
+
// Compute angular distance: 1 - cosine_similarity
|
|
61
|
+
nk_f64_t result = 1.0 - ab / (a_sqrt * b_sqrt);
|
|
62
|
+
|
|
63
|
+
// Clamp negative results to 0 (can occur due to floating-point rounding)
|
|
64
|
+
return result > 0.0 ? result : 0.0;
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
#pragma region - Traditional Floats
|
|
68
|
+
|
|
69
|
+
NK_PUBLIC void nk_sqeuclidean_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    // Squared Euclidean distance over f32 inputs, accumulated in f64 lanes.
    v128_t acc_f64x2 = wasm_f64x2_splat(0.0);
    nk_f32_t const *a_ptr = a, *b_ptr = b;
    nk_size_t remaining = n;
    nk_b64_vec_t a_chunk, b_chunk;

    do {
        if (remaining >= 2) {
            // Full 2-element step.
            nk_load_b64_serial_(a_ptr, &a_chunk);
            nk_load_b64_serial_(b_ptr, &b_chunk);
            a_ptr += 2, b_ptr += 2, remaining -= 2;
        }
        else {
            // Tail: zero-padded partial load, so the unused lane adds nothing.
            nk_partial_load_b32x2_serial_(a_ptr, &a_chunk, remaining);
            nk_partial_load_b32x2_serial_(b_ptr, &b_chunk, remaining);
            remaining = 0;
        }
        // Promote f32x2 to f64x2 before accumulating (a - b)^2.
        v128_t a_pair_f64x2 = wasm_f64x2_promote_low_f32x4(wasm_v128_load64_zero(&a_chunk.u64));
        v128_t b_pair_f64x2 = wasm_f64x2_promote_low_f32x4(wasm_v128_load64_zero(&b_chunk.u64));
        v128_t delta_f64x2 = wasm_f64x2_sub(a_pair_f64x2, b_pair_f64x2);
        acc_f64x2 = wasm_f64x2_relaxed_madd(delta_f64x2, delta_f64x2, acc_f64x2);
    } while (remaining);

    *result = nk_reduce_add_f64x2_v128relaxed_(acc_f64x2);
}
|
|
96
|
+
|
|
97
|
+
NK_PUBLIC void nk_sqeuclidean_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    // Squared Euclidean distance over f64 inputs, two lanes per step.
    v128_t acc_f64x2 = wasm_f64x2_splat(0.0);
    nk_f64_t const *a_ptr = a, *b_ptr = b;
    nk_size_t remaining = n;
    nk_b128_vec_t a_chunk, b_chunk;

    do {
        if (remaining >= 2) {
            // Full 128-bit step.
            nk_load_b128_v128relaxed_(a_ptr, &a_chunk);
            nk_load_b128_v128relaxed_(b_ptr, &b_chunk);
            a_ptr += 2, b_ptr += 2, remaining -= 2;
        }
        else {
            // Tail: zero-padded partial load keeps the extra lane inert.
            nk_partial_load_b64x2_serial_(a_ptr, &a_chunk, remaining);
            nk_partial_load_b64x2_serial_(b_ptr, &b_chunk, remaining);
            remaining = 0;
        }
        v128_t delta_f64x2 = wasm_f64x2_sub(a_chunk.v128, b_chunk.v128);
        acc_f64x2 = wasm_f64x2_relaxed_madd(delta_f64x2, delta_f64x2, acc_f64x2);
    } while (remaining);

    *result = nk_reduce_add_f64x2_v128relaxed_(acc_f64x2);
}
|
|
120
|
+
|
|
121
|
+
NK_PUBLIC void nk_euclidean_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    // Euclidean distance is the square root of the squared-Euclidean kernel.
    nk_f64_t squared_distance;
    nk_sqeuclidean_f32_v128relaxed(a, b, n, &squared_distance);
    *result = nk_f64_sqrt_v128relaxed(squared_distance);
}
|
|
126
|
+
|
|
127
|
+
NK_PUBLIC void nk_euclidean_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    // Euclidean distance is the square root of the squared-Euclidean kernel.
    nk_f64_t squared_distance;
    nk_sqeuclidean_f64_v128relaxed(a, b, n, &squared_distance);
    *result = nk_f64_sqrt_v128relaxed(squared_distance);
}
|
|
132
|
+
|
|
133
|
+
NK_PUBLIC void nk_angular_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    // Angular distance over f32 inputs; all three sums are accumulated in
    // f64 lanes for numerical stability (matches the Haswell/NEON strategy).
    v128_t dot_acc_f64x2 = wasm_f64x2_splat(0.0);
    v128_t a_sq_acc_f64x2 = wasm_f64x2_splat(0.0);
    v128_t b_sq_acc_f64x2 = wasm_f64x2_splat(0.0);
    nk_f32_t const *a_ptr = a, *b_ptr = b;
    nk_size_t remaining = n;
    nk_b64_vec_t a_chunk, b_chunk;

    do {
        if (remaining >= 2) {
            // Full 2-element step.
            nk_load_b64_serial_(a_ptr, &a_chunk);
            nk_load_b64_serial_(b_ptr, &b_chunk);
            a_ptr += 2, b_ptr += 2, remaining -= 2;
        }
        else {
            // Tail: zero-padded partial load, so the unused lane adds nothing.
            nk_partial_load_b32x2_serial_(a_ptr, &a_chunk, remaining);
            nk_partial_load_b32x2_serial_(b_ptr, &b_chunk, remaining);
            remaining = 0;
        }

        // Promote f32x2 to f64x2 before the three fused accumulations.
        v128_t a_pair_f64x2 = wasm_f64x2_promote_low_f32x4(wasm_v128_load64_zero(&a_chunk.u64));
        v128_t b_pair_f64x2 = wasm_f64x2_promote_low_f32x4(wasm_v128_load64_zero(&b_chunk.u64));

        // dot += a*b, a2 += a*a, b2 += b*b
        dot_acc_f64x2 = wasm_f64x2_relaxed_madd(a_pair_f64x2, b_pair_f64x2, dot_acc_f64x2);
        a_sq_acc_f64x2 = wasm_f64x2_relaxed_madd(a_pair_f64x2, a_pair_f64x2, a_sq_acc_f64x2);
        b_sq_acc_f64x2 = wasm_f64x2_relaxed_madd(b_pair_f64x2, b_pair_f64x2, b_sq_acc_f64x2);
    } while (remaining);

    // Horizontal sums, then shared normalization (handles zero vectors and clamping).
    nk_f64_t dot = nk_reduce_add_f64x2_v128relaxed_(dot_acc_f64x2);
    nk_f64_t a_sq = nk_reduce_add_f64x2_v128relaxed_(a_sq_acc_f64x2);
    nk_f64_t b_sq = nk_reduce_add_f64x2_v128relaxed_(b_sq_acc_f64x2);
    *result = nk_angular_normalize_f64_v128relaxed_(dot, a_sq, b_sq);
}
|
|
172
|
+
|
|
173
|
+
NK_PUBLIC void nk_angular_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    // Angular distance over f64 inputs; three accumulators, two lanes each.
    v128_t dot_acc_f64x2 = wasm_f64x2_splat(0.0);
    v128_t a_sq_acc_f64x2 = wasm_f64x2_splat(0.0);
    v128_t b_sq_acc_f64x2 = wasm_f64x2_splat(0.0);
    nk_f64_t const *a_ptr = a, *b_ptr = b;
    nk_size_t remaining = n;
    nk_b128_vec_t a_chunk, b_chunk;

    do {
        if (remaining >= 2) {
            // Full 128-bit step.
            nk_load_b128_v128relaxed_(a_ptr, &a_chunk);
            nk_load_b128_v128relaxed_(b_ptr, &b_chunk);
            a_ptr += 2, b_ptr += 2, remaining -= 2;
        }
        else {
            // Tail: zero-padded partial load keeps the extra lane inert.
            nk_partial_load_b64x2_serial_(a_ptr, &a_chunk, remaining);
            nk_partial_load_b64x2_serial_(b_ptr, &b_chunk, remaining);
            remaining = 0;
        }

        // dot += a*b, a2 += a*a, b2 += b*b
        dot_acc_f64x2 = wasm_f64x2_relaxed_madd(a_chunk.v128, b_chunk.v128, dot_acc_f64x2);
        a_sq_acc_f64x2 = wasm_f64x2_relaxed_madd(a_chunk.v128, a_chunk.v128, a_sq_acc_f64x2);
        b_sq_acc_f64x2 = wasm_f64x2_relaxed_madd(b_chunk.v128, b_chunk.v128, b_sq_acc_f64x2);
    } while (remaining);

    // Horizontal sums, then shared normalization (handles zero vectors and clamping).
    nk_f64_t dot = nk_reduce_add_f64x2_v128relaxed_(dot_acc_f64x2);
    nk_f64_t a_sq = nk_reduce_add_f64x2_v128relaxed_(a_sq_acc_f64x2);
    nk_f64_t b_sq = nk_reduce_add_f64x2_v128relaxed_(b_sq_acc_f64x2);
    *result = nk_angular_normalize_f64_v128relaxed_(dot, a_sq, b_sq);
}
|
|
205
|
+
|
|
206
|
+
#pragma endregion - Traditional Floats
|
|
207
|
+
#pragma region - Smaller Floats
|
|
208
|
+
|
|
209
|
+
NK_PUBLIC void nk_sqeuclidean_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Squared Euclidean distance over f16 inputs, accumulated in f32 lanes.
    v128_t acc_f32x4 = wasm_f32x4_splat(0.0f);
    nk_f16_t const *a_ptr = a, *b_ptr = b;
    nk_size_t remaining = n;
    nk_b64_vec_t a_chunk, b_chunk;

    do {
        if (remaining >= 4) {
            // Full 4-element step.
            nk_load_b64_serial_(a_ptr, &a_chunk);
            nk_load_b64_serial_(b_ptr, &b_chunk);
            a_ptr += 4, b_ptr += 4, remaining -= 4;
        }
        else {
            // Tail: zero-padded partial load, so unused lanes add nothing.
            nk_partial_load_b16x4_serial_(a_ptr, &a_chunk, remaining);
            nk_partial_load_b16x4_serial_(b_ptr, &b_chunk, remaining);
            remaining = 0;
        }

        // Widen f16x4 to f32x4, then accumulate (a - b)^2.
        nk_b128_vec_t a_wide = nk_f16x4_to_f32x4_v128relaxed_(a_chunk);
        nk_b128_vec_t b_wide = nk_f16x4_to_f32x4_v128relaxed_(b_chunk);
        v128_t delta_f32x4 = wasm_f32x4_sub(a_wide.v128, b_wide.v128);
        acc_f32x4 = wasm_f32x4_relaxed_madd(delta_f32x4, delta_f32x4, acc_f32x4);
    } while (remaining);

    *result = nk_reduce_add_f32x4_v128relaxed_(acc_f32x4);
}
|
|
240
|
+
|
|
241
|
+
NK_PUBLIC void nk_euclidean_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Euclidean distance is the square root of the squared-Euclidean kernel.
    nk_f32_t squared_distance;
    nk_sqeuclidean_f16_v128relaxed(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_v128relaxed(squared_distance);
}
|
|
246
|
+
|
|
247
|
+
NK_PUBLIC void nk_angular_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance over f16 inputs; three f32 accumulators, four lanes each.
    v128_t dot_acc_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t a_sq_acc_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t b_sq_acc_f32x4 = wasm_f32x4_splat(0.0f);
    nk_f16_t const *a_ptr = a, *b_ptr = b;
    nk_size_t remaining = n;
    nk_b64_vec_t a_chunk, b_chunk;

    do {
        if (remaining >= 4) {
            // Full 4-element step.
            nk_load_b64_serial_(a_ptr, &a_chunk);
            nk_load_b64_serial_(b_ptr, &b_chunk);
            a_ptr += 4, b_ptr += 4, remaining -= 4;
        }
        else {
            // Tail: zero-padded partial load, so unused lanes add nothing.
            nk_partial_load_b16x4_serial_(a_ptr, &a_chunk, remaining);
            nk_partial_load_b16x4_serial_(b_ptr, &b_chunk, remaining);
            remaining = 0;
        }

        // Widen f16x4 to f32x4 before the three fused accumulations.
        nk_b128_vec_t a_wide = nk_f16x4_to_f32x4_v128relaxed_(a_chunk);
        nk_b128_vec_t b_wide = nk_f16x4_to_f32x4_v128relaxed_(b_chunk);

        // dot += a*b, a2 += a*a, b2 += b*b
        dot_acc_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, b_wide.v128, dot_acc_f32x4);
        a_sq_acc_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, a_wide.v128, a_sq_acc_f32x4);
        b_sq_acc_f32x4 = wasm_f32x4_relaxed_madd(b_wide.v128, b_wide.v128, b_sq_acc_f32x4);
    } while (remaining);

    // Horizontal sums.
    nk_f32_t dot = nk_reduce_add_f32x4_v128relaxed_(dot_acc_f32x4);
    nk_f32_t a_sq = nk_reduce_add_f32x4_v128relaxed_(a_sq_acc_f32x4);
    nk_f32_t b_sq = nk_reduce_add_f32x4_v128relaxed_(b_sq_acc_f32x4);

    // Normalize through the f64 helper, which handles zero vectors,
    // perpendicular inputs, and the non-negative clamp.
    *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)dot, (nk_f64_t)a_sq, (nk_f64_t)b_sq);
}
|
|
286
|
+
|
|
287
|
+
NK_PUBLIC void nk_sqeuclidean_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Squared Euclidean distance over bf16 inputs, accumulated in f32 lanes.
    v128_t acc_f32x4 = wasm_f32x4_splat(0.0f);
    nk_bf16_t const *a_ptr = a, *b_ptr = b;
    nk_size_t remaining = n;
    nk_b64_vec_t a_chunk, b_chunk;

    do {
        if (remaining >= 4) {
            // Full 4-element step.
            nk_load_b64_serial_(a_ptr, &a_chunk);
            nk_load_b64_serial_(b_ptr, &b_chunk);
            a_ptr += 4, b_ptr += 4, remaining -= 4;
        }
        else {
            // Tail: zero-padded partial load, so unused lanes add nothing.
            nk_partial_load_b16x4_serial_(a_ptr, &a_chunk, remaining);
            nk_partial_load_b16x4_serial_(b_ptr, &b_chunk, remaining);
            remaining = 0;
        }

        // Widen bf16x4 to f32x4, then accumulate (a - b)^2.
        nk_b128_vec_t a_wide = nk_bf16x4_to_f32x4_v128relaxed_(a_chunk);
        nk_b128_vec_t b_wide = nk_bf16x4_to_f32x4_v128relaxed_(b_chunk);
        v128_t delta_f32x4 = wasm_f32x4_sub(a_wide.v128, b_wide.v128);
        acc_f32x4 = wasm_f32x4_relaxed_madd(delta_f32x4, delta_f32x4, acc_f32x4);
    } while (remaining);

    *result = nk_reduce_add_f32x4_v128relaxed_(acc_f32x4);
}
|
|
318
|
+
|
|
319
|
+
NK_PUBLIC void nk_euclidean_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Euclidean distance is the square root of the squared-Euclidean kernel.
    nk_f32_t squared_distance;
    nk_sqeuclidean_bf16_v128relaxed(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_v128relaxed(squared_distance);
}
|
|
324
|
+
|
|
325
|
+
NK_PUBLIC void nk_angular_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance over bf16 inputs; three f32 accumulators, four lanes each.
    v128_t dot_acc_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t a_sq_acc_f32x4 = wasm_f32x4_splat(0.0f);
    v128_t b_sq_acc_f32x4 = wasm_f32x4_splat(0.0f);
    nk_bf16_t const *a_ptr = a, *b_ptr = b;
    nk_size_t remaining = n;
    nk_b64_vec_t a_chunk, b_chunk;

    do {
        if (remaining >= 4) {
            // Full 4-element step.
            nk_load_b64_serial_(a_ptr, &a_chunk);
            nk_load_b64_serial_(b_ptr, &b_chunk);
            a_ptr += 4, b_ptr += 4, remaining -= 4;
        }
        else {
            // Tail: zero-padded partial load, so unused lanes add nothing.
            nk_partial_load_b16x4_serial_(a_ptr, &a_chunk, remaining);
            nk_partial_load_b16x4_serial_(b_ptr, &b_chunk, remaining);
            remaining = 0;
        }

        // Widen bf16x4 to f32x4 before the three fused accumulations.
        nk_b128_vec_t a_wide = nk_bf16x4_to_f32x4_v128relaxed_(a_chunk);
        nk_b128_vec_t b_wide = nk_bf16x4_to_f32x4_v128relaxed_(b_chunk);

        // dot += a*b, a2 += a*a, b2 += b*b
        dot_acc_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, b_wide.v128, dot_acc_f32x4);
        a_sq_acc_f32x4 = wasm_f32x4_relaxed_madd(a_wide.v128, a_wide.v128, a_sq_acc_f32x4);
        b_sq_acc_f32x4 = wasm_f32x4_relaxed_madd(b_wide.v128, b_wide.v128, b_sq_acc_f32x4);
    } while (remaining);

    // Horizontal sums.
    nk_f32_t dot = nk_reduce_add_f32x4_v128relaxed_(dot_acc_f32x4);
    nk_f32_t a_sq = nk_reduce_add_f32x4_v128relaxed_(a_sq_acc_f32x4);
    nk_f32_t b_sq = nk_reduce_add_f32x4_v128relaxed_(b_sq_acc_f32x4);

    // Normalize through the f64 helper, which handles zero vectors,
    // perpendicular inputs, and the non-negative clamp.
    *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)dot, (nk_f64_t)a_sq, (nk_f64_t)b_sq);
}
|
|
364
|
+
|
|
365
|
+
#pragma endregion - Smaller Floats
|
|
366
|
+
#pragma region - Spatial From-Dot Helpers
|
|
367
|
+
|
|
368
|
+
/** @brief Angular from_dot: computes 1 − dot / √(query_sumsq × target_sumsq) for 4 pairs in f32. */
|
|
369
|
+
NK_INTERNAL void nk_angular_through_f32_from_dot_v128relaxed_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
370
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
371
|
+
v128_t dots_f32x4 = dots.v128;
|
|
372
|
+
v128_t query_sumsq_f32x4 = wasm_f32x4_splat(query_sumsq);
|
|
373
|
+
v128_t products_f32x4 = wasm_f32x4_mul(query_sumsq_f32x4, target_sumsqs.v128);
|
|
374
|
+
v128_t sqrt_products_f32x4 = wasm_f32x4_sqrt(products_f32x4);
|
|
375
|
+
v128_t normalized_f32x4 = wasm_f32x4_div(dots_f32x4, sqrt_products_f32x4);
|
|
376
|
+
v128_t angular_f32x4 = wasm_f32x4_sub(wasm_f32x4_splat(1.0f), normalized_f32x4);
|
|
377
|
+
results->v128 = wasm_f32x4_max(angular_f32x4, wasm_f32x4_splat(0.0f));
|
|
378
|
+
}
|
|
379
|
+
|
|
380
|
+
/** @brief Euclidean from_dot: computes √(query_sumsq + target_sumsq − 2 × dot) for 4 pairs in f32. */
|
|
381
|
+
NK_INTERNAL void nk_euclidean_through_f32_from_dot_v128relaxed_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
382
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
383
|
+
v128_t dots_f32x4 = dots.v128;
|
|
384
|
+
v128_t query_sumsq_f32x4 = wasm_f32x4_splat(query_sumsq);
|
|
385
|
+
v128_t two_f32x4 = wasm_f32x4_splat(2.0f);
|
|
386
|
+
v128_t sum_sq_f32x4 = wasm_f32x4_add(query_sumsq_f32x4, target_sumsqs.v128);
|
|
387
|
+
v128_t dist_sq_f32x4 = wasm_f32x4_relaxed_nmadd(two_f32x4, dots_f32x4, sum_sq_f32x4);
|
|
388
|
+
dist_sq_f32x4 = wasm_f32x4_max(dist_sq_f32x4, wasm_f32x4_splat(0.0f));
|
|
389
|
+
results->v128 = wasm_f32x4_sqrt(dist_sq_f32x4);
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
/** @brief Angular from_dot for i32 accumulators: cast to f32, then angular normalization. 4 pairs. */
|
|
393
|
+
NK_INTERNAL void nk_angular_through_i32_from_dot_v128relaxed_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
|
|
394
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
395
|
+
v128_t dots_f32x4 = wasm_f32x4_convert_i32x4(dots.v128);
|
|
396
|
+
v128_t query_sumsq_f32x4 = wasm_f32x4_splat((nk_f32_t)query_sumsq);
|
|
397
|
+
v128_t products_f32x4 = wasm_f32x4_mul(query_sumsq_f32x4, wasm_f32x4_convert_i32x4(target_sumsqs.v128));
|
|
398
|
+
v128_t sqrt_products_f32x4 = wasm_f32x4_sqrt(products_f32x4);
|
|
399
|
+
v128_t normalized_f32x4 = wasm_f32x4_div(dots_f32x4, sqrt_products_f32x4);
|
|
400
|
+
v128_t angular_f32x4 = wasm_f32x4_sub(wasm_f32x4_splat(1.0f), normalized_f32x4);
|
|
401
|
+
results->v128 = wasm_f32x4_max(angular_f32x4, wasm_f32x4_splat(0.0f));
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
/** @brief Euclidean from_dot for i32 accumulators: cast to f32, then √(a² + b² − 2ab). 4 pairs. */
|
|
405
|
+
NK_INTERNAL void nk_euclidean_through_i32_from_dot_v128relaxed_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
|
|
406
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
407
|
+
v128_t dots_f32x4 = wasm_f32x4_convert_i32x4(dots.v128);
|
|
408
|
+
v128_t query_sumsq_f32x4 = wasm_f32x4_splat((nk_f32_t)query_sumsq);
|
|
409
|
+
v128_t two_f32x4 = wasm_f32x4_splat(2.0f);
|
|
410
|
+
v128_t sum_sq_f32x4 = wasm_f32x4_add(query_sumsq_f32x4, wasm_f32x4_convert_i32x4(target_sumsqs.v128));
|
|
411
|
+
v128_t dist_sq_f32x4 = wasm_f32x4_relaxed_nmadd(two_f32x4, dots_f32x4, sum_sq_f32x4);
|
|
412
|
+
dist_sq_f32x4 = wasm_f32x4_max(dist_sq_f32x4, wasm_f32x4_splat(0.0f));
|
|
413
|
+
results->v128 = wasm_f32x4_sqrt(dist_sq_f32x4);
|
|
414
|
+
}
|
|
415
|
+
|
|
416
|
+
/** @brief Angular from_dot for u32 accumulators: cast to f32, then angular normalization. 4 pairs. */
|
|
417
|
+
NK_INTERNAL void nk_angular_through_u32_from_dot_v128relaxed_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
|
|
418
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
419
|
+
v128_t dots_f32x4 = wasm_f32x4_convert_u32x4(dots.v128);
|
|
420
|
+
v128_t query_sumsq_f32x4 = wasm_f32x4_splat((nk_f32_t)query_sumsq);
|
|
421
|
+
v128_t products_f32x4 = wasm_f32x4_mul(query_sumsq_f32x4, wasm_f32x4_convert_u32x4(target_sumsqs.v128));
|
|
422
|
+
v128_t sqrt_products_f32x4 = wasm_f32x4_sqrt(products_f32x4);
|
|
423
|
+
v128_t normalized_f32x4 = wasm_f32x4_div(dots_f32x4, sqrt_products_f32x4);
|
|
424
|
+
v128_t angular_f32x4 = wasm_f32x4_sub(wasm_f32x4_splat(1.0f), normalized_f32x4);
|
|
425
|
+
results->v128 = wasm_f32x4_max(angular_f32x4, wasm_f32x4_splat(0.0f));
|
|
426
|
+
}
|
|
427
|
+
|
|
428
|
+
/** @brief Euclidean from_dot for u32 accumulators: cast to f32, then √(a² + b² − 2ab). 4 pairs. */
|
|
429
|
+
NK_INTERNAL void nk_euclidean_through_u32_from_dot_v128relaxed_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
|
|
430
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
431
|
+
v128_t dots_f32x4 = wasm_f32x4_convert_u32x4(dots.v128);
|
|
432
|
+
v128_t query_sumsq_f32x4 = wasm_f32x4_splat((nk_f32_t)query_sumsq);
|
|
433
|
+
v128_t two_f32x4 = wasm_f32x4_splat(2.0f);
|
|
434
|
+
v128_t sum_sq_f32x4 = wasm_f32x4_add(query_sumsq_f32x4, wasm_f32x4_convert_u32x4(target_sumsqs.v128));
|
|
435
|
+
v128_t dist_sq_f32x4 = wasm_f32x4_relaxed_nmadd(two_f32x4, dots_f32x4, sum_sq_f32x4);
|
|
436
|
+
dist_sq_f32x4 = wasm_f32x4_max(dist_sq_f32x4, wasm_f32x4_splat(0.0f));
|
|
437
|
+
results->v128 = wasm_f32x4_sqrt(dist_sq_f32x4);
|
|
438
|
+
}
|
|
439
|
+
|
|
440
|
+
#pragma endregion - Spatial From - Dot Helpers
|
|
441
|
+
#pragma region - Integer Spatial
|
|
442
|
+
|
|
443
|
+
NK_PUBLIC void nk_sqeuclidean_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
    /* Squared Euclidean distance over u8 vectors, accumulated in four u32 lanes. */
    v128_t acc_u32x4 = wasm_u32x4_splat(0);
    nk_u8_t const *a_ptr = a, *b_ptr = b;
    nk_size_t remaining = n;
    v128_t a_bytes_u8x16, b_bytes_u8x16;

    do {
        // Full 16-byte steps; the final partial step loads through zeroed staging,
        // so padding lanes contribute a zero difference.
        if (remaining >= 16) {
            a_bytes_u8x16 = wasm_v128_load(a_ptr);
            b_bytes_u8x16 = wasm_v128_load(b_ptr);
            a_ptr += 16, b_ptr += 16, remaining -= 16;
        }
        else {
            nk_b128_vec_t a_staged = {0}, b_staged = {0};
            nk_partial_load_b8x16_serial_(a_ptr, &a_staged, remaining);
            nk_partial_load_b8x16_serial_(b_ptr, &b_staged, remaining);
            a_bytes_u8x16 = a_staged.v128;
            b_bytes_u8x16 = b_staged.v128;
            remaining = 0;
        }

        // |a−b| without widening: saturating subtractions in both directions, OR'ed
        // (one of the two is always zero).
        v128_t abs_diff_u8x16 =
            wasm_v128_or(wasm_u8x16_sub_sat(a_bytes_u8x16, b_bytes_u8x16), wasm_u8x16_sub_sat(b_bytes_u8x16, a_bytes_u8x16));

        // Widen to 16-bit halves and square via extending multiplies; lanes are
        // at most 255, so the signed i16 interpretation is exact.
        v128_t diff_lo_u16x8 = wasm_u16x8_extend_low_u8x16(abs_diff_u8x16);
        v128_t diff_hi_u16x8 = wasm_u16x8_extend_high_u8x16(abs_diff_u8x16);
        acc_u32x4 = wasm_i32x4_add(acc_u32x4, wasm_i32x4_extmul_low_i16x8(diff_lo_u16x8, diff_lo_u16x8));
        acc_u32x4 = wasm_i32x4_add(acc_u32x4, wasm_i32x4_extmul_high_i16x8(diff_lo_u16x8, diff_lo_u16x8));
        acc_u32x4 = wasm_i32x4_add(acc_u32x4, wasm_i32x4_extmul_low_i16x8(diff_hi_u16x8, diff_hi_u16x8));
        acc_u32x4 = wasm_i32x4_add(acc_u32x4, wasm_i32x4_extmul_high_i16x8(diff_hi_u16x8, diff_hi_u16x8));
    } while (remaining);

    *result = nk_reduce_add_u32x4_v128relaxed_(acc_u32x4);
}
|
|
478
|
+
|
|
479
|
+
NK_PUBLIC void nk_euclidean_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    /* Euclidean distance is the square root of the squared-distance kernel. */
    nk_u32_t squared;
    nk_sqeuclidean_u8_v128relaxed(a, b, n, &squared);
    *result = nk_f32_sqrt_v128relaxed((nk_f32_t)squared);
}
|
|
484
|
+
|
|
485
|
+
/** @brief Angular (cosine) distance over u8 vectors via the biased relaxed_dot decomposition. */
NK_PUBLIC void nk_angular_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
    // Bias u8 [0,255] → i8 [-128,127] via XOR 0x80, then use the i8 magnitude+sign
    // decomposition for saturation-safe relaxed_dot.
    //
    // The XOR-only approach (passing raw u8 as first operand) causes vpmaddubsw saturation:
    // u8*i8 pairwise sums can reach 64770, exceeding i16 max (32767).
    // Biasing first ensures i8*u7 products stay in [-16256, 16129], pairs in [-32512, 32258].
    //
    // Let a' = a - 128, b' = b - 128 (via XOR 0x80).
    // Compute biased dots via relaxed_dot with i7 magnitude trick:
    //     a'·b' = relaxed_dot(a', b'&0x7F) - 128·Σ(a'[i] where b'[i]<0)
    // Then recover true unsigned dots:
    //     a·b = a'·b' + 128·(Σa + Σb) - n·16384
    //     a·a = a'·a' + 256·Σa - n·16384
    //     b·b = b'·b' + 256·Σb - n·16384
    nk_i64_t biased_ab = 0, biased_aa = 0, biased_bb = 0;
    nk_i64_t sum_a_total = 0, sum_b_total = 0;
    nk_size_t i = 0;

    // Windowed accumulation loop: i16/u16 partial sums live only inside one window.
    while (i + 16 <= n) {
        v128_t dot_ab_i32x4 = wasm_i32x4_splat(0);
        v128_t dot_aa_i32x4 = wasm_i32x4_splat(0);
        v128_t dot_bb_i32x4 = wasm_i32x4_splat(0);
        v128_t corr_ab_i16x8 = wasm_i16x8_splat(0);
        v128_t corr_aa_i16x8 = wasm_i16x8_splat(0);
        v128_t corr_bb_i16x8 = wasm_i16x8_splat(0);
        v128_t sum_a_u16x8 = wasm_u16x8_splat(0);
        v128_t sum_b_u16x8 = wasm_u16x8_splat(0);

        // Inner loop: accumulate 127 iterations before widening corrections
        // Overflow safety: max i16 lane = 127 × 254 = 32258 < 32767
        nk_size_t cycle = 0;
        for (; cycle < 127 && i + 16 <= n; ++cycle, i += 16) {
            v128_t a_u8x16 = wasm_v128_load(a + i);
            v128_t b_u8x16 = wasm_v128_load(b + i);

            // Bias to signed: a' = a ^ 0x80, b' = b ^ 0x80
            v128_t a_i8x16 = wasm_v128_xor(a_u8x16, wasm_i8x16_splat((char)0x80));
            v128_t b_i8x16 = wasm_v128_xor(b_u8x16, wasm_i8x16_splat((char)0x80));

            // Clear sign bit to get 7-bit unsigned magnitudes
            v128_t a_7bit_u8x16 = wasm_v128_and(a_i8x16, wasm_i8x16_splat(0x7F));
            v128_t b_7bit_u8x16 = wasm_v128_and(b_i8x16, wasm_i8x16_splat(0x7F));

            // Negative masks for correction
            v128_t a_neg_mask_i8x16 = wasm_i8x16_lt(a_i8x16, wasm_i8x16_splat(0));
            v128_t b_neg_mask_i8x16 = wasm_i8x16_lt(b_i8x16, wasm_i8x16_splat(0));

            // Three relaxed_dot calls on biased values
            dot_ab_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_i8x16, b_7bit_u8x16, dot_ab_i32x4);
            dot_aa_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_i8x16, a_7bit_u8x16, dot_aa_i32x4);
            dot_bb_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(b_i8x16, b_7bit_u8x16, dot_bb_i32x4);

            // Accumulate corrections in i16 (1 widening/iter instead of 2)
            v128_t a_where_b_neg = wasm_v128_and(a_i8x16, b_neg_mask_i8x16);
            v128_t a_where_a_neg = wasm_v128_and(a_i8x16, a_neg_mask_i8x16);
            v128_t b_where_b_neg = wasm_v128_and(b_i8x16, b_neg_mask_i8x16);
            corr_ab_i16x8 = wasm_i16x8_add(corr_ab_i16x8, wasm_i16x8_extadd_pairwise_i8x16(a_where_b_neg));
            corr_aa_i16x8 = wasm_i16x8_add(corr_aa_i16x8, wasm_i16x8_extadd_pairwise_i8x16(a_where_a_neg));
            corr_bb_i16x8 = wasm_i16x8_add(corr_bb_i16x8, wasm_i16x8_extadd_pairwise_i8x16(b_where_b_neg));

            // Unsigned sums for final unbias correction
            sum_a_u16x8 = wasm_i16x8_add(sum_a_u16x8, wasm_u16x8_extadd_pairwise_u8x16(a_u8x16));
            sum_b_u16x8 = wasm_i16x8_add(sum_b_u16x8, wasm_u16x8_extadd_pairwise_u8x16(b_u8x16));
        }

        // Deferred widening: i16/u16 → i32/u32 once per window
        v128_t corr_ab_i32x4 = wasm_i32x4_extadd_pairwise_i16x8(corr_ab_i16x8);
        v128_t corr_aa_i32x4 = wasm_i32x4_extadd_pairwise_i16x8(corr_aa_i16x8);
        v128_t corr_bb_i32x4 = wasm_i32x4_extadd_pairwise_i16x8(corr_bb_i16x8);
        v128_t sum_a_u32x4 = wasm_u32x4_extadd_pairwise_u16x8(sum_a_u16x8);
        v128_t sum_b_u32x4 = wasm_u32x4_extadd_pairwise_u16x8(sum_b_u16x8);
        // Fold this window's dots (minus the 128× sign corrections) into the i64 totals.
        biased_ab += nk_reduce_add_i32x4_v128relaxed_(dot_ab_i32x4) -
                     128LL * nk_reduce_add_i32x4_v128relaxed_(corr_ab_i32x4);
        biased_aa += nk_reduce_add_i32x4_v128relaxed_(dot_aa_i32x4) -
                     128LL * nk_reduce_add_i32x4_v128relaxed_(corr_aa_i32x4);
        biased_bb += nk_reduce_add_i32x4_v128relaxed_(dot_bb_i32x4) -
                     128LL * nk_reduce_add_i32x4_v128relaxed_(corr_bb_i32x4);
        sum_a_total += nk_reduce_add_u32x4_v128relaxed_(sum_a_u32x4);
        sum_b_total += nk_reduce_add_u32x4_v128relaxed_(sum_b_u32x4);
    }

    // Scalar tail: compute biased products directly
    for (; i < n; i++) {
        nk_i32_t a_biased = (nk_i32_t)a[i] - 128;
        nk_i32_t b_biased = (nk_i32_t)b[i] - 128;
        biased_ab += (nk_i64_t)a_biased * b_biased;
        biased_aa += (nk_i64_t)a_biased * a_biased;
        biased_bb += (nk_i64_t)b_biased * b_biased;
        sum_a_total += a[i];
        sum_b_total += b[i];
    }

    // Recover true unsigned dots from biased:
    //     a·b = (a-128)·(b-128) + 128·Σa + 128·Σb - n·16384
    //     a·a = (a-128)·(a-128) + 256·Σa - n·16384
    //     b·b = (b-128)·(b-128) + 256·Σb - n·16384
    nk_i64_t n_correction = (nk_i64_t)n * 16384LL;
    nk_f64_t dot_ab = (nk_f64_t)(biased_ab + 128LL * (sum_a_total + sum_b_total) - n_correction);
    nk_f64_t norm_aa = (nk_f64_t)(biased_aa + 256LL * sum_a_total - n_correction);
    nk_f64_t norm_bb = (nk_f64_t)(biased_bb + 256LL * sum_b_total - n_correction);
    *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_(dot_ab, norm_aa, norm_bb);
}
|
|
589
|
+
|
|
590
|
+
NK_PUBLIC void nk_sqeuclidean_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
    /* Squared Euclidean distance over i8 vectors. XOR with 0x80 maps the signed
     * bytes onto the unsigned range while preserving order, and |a−b|² is
     * invariant under that uniform shift — so the u8 algorithm applies as-is. */
    v128_t acc_u32x4 = wasm_u32x4_splat(0);
    v128_t sign_flip_u8x16 = wasm_u8x16_splat(0x80);
    nk_i8_t const *a_ptr = a, *b_ptr = b;
    nk_size_t remaining = n;
    v128_t a_bytes_u8x16, b_bytes_u8x16;

    do {
        // Full 16-byte steps; the final partial step loads through zeroed staging.
        // Padding bytes become 0x80 in both operands after the flip: zero difference.
        if (remaining >= 16) {
            a_bytes_u8x16 = wasm_v128_xor(wasm_v128_load(a_ptr), sign_flip_u8x16);
            b_bytes_u8x16 = wasm_v128_xor(wasm_v128_load(b_ptr), sign_flip_u8x16);
            a_ptr += 16, b_ptr += 16, remaining -= 16;
        }
        else {
            nk_b128_vec_t a_staged = {0}, b_staged = {0};
            nk_partial_load_b8x16_serial_(a_ptr, &a_staged, remaining);
            nk_partial_load_b8x16_serial_(b_ptr, &b_staged, remaining);
            a_bytes_u8x16 = wasm_v128_xor(a_staged.v128, sign_flip_u8x16);
            b_bytes_u8x16 = wasm_v128_xor(b_staged.v128, sign_flip_u8x16);
            remaining = 0;
        }

        // |a−b| via saturating subtractions in both directions, OR'ed together.
        v128_t abs_diff_u8x16 =
            wasm_v128_or(wasm_u8x16_sub_sat(a_bytes_u8x16, b_bytes_u8x16), wasm_u8x16_sub_sat(b_bytes_u8x16, a_bytes_u8x16));

        // Widen to 16-bit halves and square via extending multiplies; lanes are
        // at most 255, so the signed i16 interpretation is exact.
        v128_t diff_lo_u16x8 = wasm_u16x8_extend_low_u8x16(abs_diff_u8x16);
        v128_t diff_hi_u16x8 = wasm_u16x8_extend_high_u8x16(abs_diff_u8x16);
        acc_u32x4 = wasm_i32x4_add(acc_u32x4, wasm_i32x4_extmul_low_i16x8(diff_lo_u16x8, diff_lo_u16x8));
        acc_u32x4 = wasm_i32x4_add(acc_u32x4, wasm_i32x4_extmul_high_i16x8(diff_lo_u16x8, diff_lo_u16x8));
        acc_u32x4 = wasm_i32x4_add(acc_u32x4, wasm_i32x4_extmul_low_i16x8(diff_hi_u16x8, diff_hi_u16x8));
        acc_u32x4 = wasm_i32x4_add(acc_u32x4, wasm_i32x4_extmul_high_i16x8(diff_hi_u16x8, diff_hi_u16x8));
    } while (remaining);

    *result = nk_reduce_add_u32x4_v128relaxed_(acc_u32x4);
}
|
|
625
|
+
|
|
626
|
+
NK_PUBLIC void nk_euclidean_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    /* Euclidean distance is the square root of the squared-distance kernel. */
    nk_u32_t squared;
    nk_sqeuclidean_i8_v128relaxed(a, b, n, &squared);
    *result = nk_f32_sqrt_v128relaxed((nk_f32_t)squared);
}
|
|
631
|
+
|
|
632
|
+
/** @brief Angular (cosine) distance over i8 vectors via the relaxed_dot sign-correction trick. */
NK_PUBLIC void nk_angular_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
    // Uses the same relaxed_dot decomposition as nk_dot_i8_v128relaxed:
    //     a·b = relaxed_dot(a, b&0x7F) - 128·Σ(a[i] where b[i]<0)
    //     a·a = relaxed_dot(a, a&0x7F) - 128·Σ(a[i] where a[i]<0)
    //     b·b = relaxed_dot(b, b&0x7F) - 128·Σ(b[i] where b[i]<0)
    nk_i64_t dot_ab_total = 0, dot_aa_total = 0, dot_bb_total = 0;
    nk_i64_t corr_ab_total = 0, corr_aa_total = 0, corr_bb_total = 0;
    nk_size_t i = 0;

    // Windowed accumulation loop: i16 correction sums live only inside one window.
    while (i + 16 <= n) {
        v128_t dot_ab_i32x4 = wasm_i32x4_splat(0);
        v128_t dot_aa_i32x4 = wasm_i32x4_splat(0);
        v128_t dot_bb_i32x4 = wasm_i32x4_splat(0);
        v128_t corr_ab_i16x8 = wasm_i16x8_splat(0); // accumulate corrections in i16
        v128_t corr_aa_i16x8 = wasm_i16x8_splat(0);
        v128_t corr_bb_i16x8 = wasm_i16x8_splat(0);

        // Inner loop: accumulate 127 iterations before widening corrections
        // Overflow safety: max i16 lane magnitude = 127 × 254 = 32258 < 32767
        nk_size_t cycle = 0;
        for (; cycle < 127 && i + 16 <= n; ++cycle, i += 16) {
            v128_t a_i8x16 = wasm_v128_load(a + i);
            v128_t b_i8x16 = wasm_v128_load(b + i);

            // Clear sign bit to get 7-bit unsigned magnitudes
            v128_t a_7bit_u8x16 = wasm_v128_and(a_i8x16, wasm_i8x16_splat(0x7F));
            v128_t b_7bit_u8x16 = wasm_v128_and(b_i8x16, wasm_i8x16_splat(0x7F));

            // Negative masks for correction
            v128_t a_neg_mask_i8x16 = wasm_i8x16_lt(a_i8x16, wasm_i8x16_splat(0));
            v128_t b_neg_mask_i8x16 = wasm_i8x16_lt(b_i8x16, wasm_i8x16_splat(0));

            // Three relaxed_dot calls
            dot_ab_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_i8x16, b_7bit_u8x16, dot_ab_i32x4);
            dot_aa_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_i8x16, a_7bit_u8x16, dot_aa_i32x4);
            dot_bb_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(b_i8x16, b_7bit_u8x16, dot_bb_i32x4);

            // Accumulate corrections in i16 (1 widening/iter instead of 2)
            v128_t a_where_b_neg = wasm_v128_and(a_i8x16, b_neg_mask_i8x16);
            v128_t a_where_a_neg = wasm_v128_and(a_i8x16, a_neg_mask_i8x16);
            v128_t b_where_b_neg = wasm_v128_and(b_i8x16, b_neg_mask_i8x16);
            corr_ab_i16x8 = wasm_i16x8_add(corr_ab_i16x8, wasm_i16x8_extadd_pairwise_i8x16(a_where_b_neg));
            corr_aa_i16x8 = wasm_i16x8_add(corr_aa_i16x8, wasm_i16x8_extadd_pairwise_i8x16(a_where_a_neg));
            corr_bb_i16x8 = wasm_i16x8_add(corr_bb_i16x8, wasm_i16x8_extadd_pairwise_i8x16(b_where_b_neg));
        }

        // Deferred widening: i16 → i32 once per window
        v128_t corr_ab_i32x4 = wasm_i32x4_extadd_pairwise_i16x8(corr_ab_i16x8);
        v128_t corr_aa_i32x4 = wasm_i32x4_extadd_pairwise_i16x8(corr_aa_i16x8);
        v128_t corr_bb_i32x4 = wasm_i32x4_extadd_pairwise_i16x8(corr_bb_i16x8);
        dot_ab_total += nk_reduce_add_i32x4_v128relaxed_(dot_ab_i32x4);
        dot_aa_total += nk_reduce_add_i32x4_v128relaxed_(dot_aa_i32x4);
        dot_bb_total += nk_reduce_add_i32x4_v128relaxed_(dot_bb_i32x4);
        corr_ab_total += nk_reduce_add_i32x4_v128relaxed_(corr_ab_i32x4);
        corr_aa_total += nk_reduce_add_i32x4_v128relaxed_(corr_aa_i32x4);
        corr_bb_total += nk_reduce_add_i32x4_v128relaxed_(corr_bb_i32x4);
    }

    // Scalar tail: true products, so no correction terms accumulate here.
    for (; i < n; i++) {
        dot_ab_total += (nk_i32_t)a[i] * (nk_i32_t)b[i];
        dot_aa_total += (nk_i32_t)a[i] * (nk_i32_t)a[i];
        dot_bb_total += (nk_i32_t)b[i] * (nk_i32_t)b[i];
    }

    // Apply correction: true_dot = relaxed_dot - 128 × correction
    // Scalar tail computes true products directly, so correction only applies to SIMD portion.
    nk_f64_t dot_ab = (nk_f64_t)(dot_ab_total - 128LL * corr_ab_total);
    nk_f64_t norm_aa = (nk_f64_t)(dot_aa_total - 128LL * corr_aa_total);
    nk_f64_t norm_bb = (nk_f64_t)(dot_bb_total - 128LL * corr_bb_total);
    *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_(dot_ab, norm_aa, norm_bb);
}
|
|
705
|
+
|
|
706
|
+
#pragma endregion - Integer Spatial
|
|
707
|
+
|
|
708
|
+
#if defined(__clang__)
|
|
709
|
+
#pragma clang attribute pop
|
|
710
|
+
#endif
|
|
711
|
+
|
|
712
|
+
#if defined(__cplusplus)
|
|
713
|
+
} // extern "C"
|
|
714
|
+
#endif
|
|
715
|
+
|
|
716
|
+
#endif // NK_TARGET_V128RELAXED
|
|
717
|
+
#endif // NK_SPATIAL_V128RELAXED_H
|