numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for RISC-V BF16.
|
|
3
|
+
* @file include/numkong/spatial/rvvbf16.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 5, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* Zvfbfwma provides widening bf16 fused multiply-accumulate to f32:
|
|
10
|
+
* vfwmaccbf16: f32 ← bf16 × bf16
|
|
11
|
+
*
|
|
12
|
+
* For L2 distance, we use the identity: (a−b)² = a² + b² − 2 × a × b
|
|
13
|
+
* This allows us to use vfwmaccbf16 for all computations.
|
|
14
|
+
*
|
|
15
|
+
* Requires: RVV 1.0 + Zvfbfwma extension (GCC 14+ or Clang 18+)
|
|
16
|
+
*/
|
|
17
|
+
#ifndef NK_SPATIAL_RVVBF16_H
|
|
18
|
+
#define NK_SPATIAL_RVVBF16_H
|
|
19
|
+
|
|
20
|
+
#if NK_TARGET_RISCV_
|
|
21
|
+
#if NK_TARGET_RVVBF16
|
|
22
|
+
|
|
23
|
+
#include "numkong/types.h"
|
|
24
|
+
#include "numkong/spatial/rvv.h" // `nk_f32_sqrt_rvv`, `nk_f32_rsqrt_rvv`
|
|
25
|
+
|
|
26
|
+
#if defined(__clang__)
|
|
27
|
+
#pragma clang attribute push(__attribute__((target("arch=+v,+zvfbfwma"))), apply_to = function)
|
|
28
|
+
#elif defined(__GNUC__)
|
|
29
|
+
#pragma GCC push_options
|
|
30
|
+
#pragma GCC target("arch=+v,+zvfbfwma")
|
|
31
|
+
#endif
|
|
32
|
+
|
|
33
|
+
#if defined(__cplusplus)
|
|
34
|
+
extern "C" {
|
|
35
|
+
#endif
|
|
36
|
+
|
|
37
|
+
/**
 *  @brief Squared Euclidean (L2²) distance between two bf16 vectors, accumulated in f32.
 *
 *  Uses the identity (a−b)² = a² + b² − 2·a·b so that every product can go through
 *  the widening `vfwmaccbf16` (bf16 × bf16 → f32) instruction — there is no bf16
 *  subtract in Zvfbfwma, so differencing directly is not an option.
 *
 *  @param a_scalars      First vector, `count_scalars` bf16 values.
 *  @param b_scalars      Second vector, `count_scalars` bf16 values.
 *  @param count_scalars  Number of scalars in each vector.
 *  @param result         Output: receives a² + b² − 2·a·b summed over all lanes.
 *                        NOTE(review): may be marginally negative due to rounding;
 *                        callers like `nk_euclidean_bf16_rvvbf16` clip before sqrt.
 */
NK_PUBLIC void nk_sqeuclidean_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars,
                                           nk_size_t count_scalars, nk_f32_t *result) {
    // Per-lane accumulators — horizontal reduction is deferred until after the loop.
    // Accumulators are zero-initialized across all `vlmax` lanes so the tail-undisturbed
    // (`_tu`) FMAs below leave unused lanes at zero, keeping the final reduction exact.
    nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t sq_sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax); // a² + b²
    vfloat32m2_t ab_sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax); // a × b

    // Strip-mined loop: `vector_length` is set inside the body before the
    // increment expression ever reads it, so the uninitialized declaration is safe.
    for (nk_size_t vector_length; count_scalars > 0;
         count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e16m1(count_scalars);
        // Load raw 16-bit payloads and reinterpret as bf16 — avoids any conversion.
        vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)a_scalars, vector_length);
        vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)b_scalars, vector_length);
        vbfloat16m1_t a_bf16m1 = __riscv_vreinterpret_v_u16m1_bf16m1(a_u16m1);
        vbfloat16m1_t b_bf16m1 = __riscv_vreinterpret_v_u16m1_bf16m1(b_u16m1);

        // Accumulate a², b², and a×b per-lane (no per-iteration reduction)
        sq_sum_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(sq_sum_f32m2, a_bf16m1, a_bf16m1, vector_length);
        sq_sum_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(sq_sum_f32m2, b_bf16m1, b_bf16m1, vector_length);
        ab_sum_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(ab_sum_f32m2, a_bf16m1, b_bf16m1, vector_length);
    }

    // Single horizontal reduction after the loop, over all `vlmax` lanes.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    nk_f32_t sq_sum = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sq_sum_f32m2, zero_f32m1, vlmax));
    nk_f32_t ab_sum = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(ab_sum_f32m2, zero_f32m1, vlmax));
    *result = sq_sum - 2.0f * ab_sum;
}
|
|
64
|
+
|
|
65
|
+
NK_PUBLIC void nk_euclidean_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars,
|
|
66
|
+
nk_size_t count_scalars, nk_f32_t *result) {
|
|
67
|
+
nk_sqeuclidean_bf16_rvvbf16(a_scalars, b_scalars, count_scalars, result);
|
|
68
|
+
// Handle potential negative values from floating point errors
|
|
69
|
+
*result = *result > 0.0f ? nk_f32_sqrt_rvv(*result) : 0.0f;
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
/**
 *  @brief Angular (cosine) distance between two bf16 vectors, accumulated in f32.
 *
 *  Computes 1 − dot(a,b) / √(‖a‖²·‖b‖²) using three widening bf16 FMA streams
 *  (dot, ‖a‖², ‖b‖²) with a single deferred horizontal reduction each.
 *
 *  @param a_scalars      First vector, `count_scalars` bf16 values.
 *  @param b_scalars      Second vector, `count_scalars` bf16 values.
 *  @param count_scalars  Number of scalars in each vector.
 *  @param result         Output distance in [0, 1+]: 0 when both vectors are all-zero,
 *                        1 when the dot product is exactly zero, otherwise the clipped
 *                        cosine distance (never negative).
 */
NK_PUBLIC void nk_angular_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
                                       nk_f32_t *result) {
    // Per-lane accumulators — deferred horizontal reduction.  Zero-initialized across
    // `vlmax` lanes so the tail-undisturbed (`_tu`) FMAs keep unused lanes at zero.
    nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
    vfloat32m2_t a_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
    vfloat32m2_t b_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);

    // Strip-mined loop; `vector_length` is assigned in the body before the increment reads it.
    for (nk_size_t vector_length; count_scalars > 0;
         count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e16m1(count_scalars);
        // Load raw 16-bit payloads and reinterpret as bf16 — no numeric conversion.
        vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)a_scalars, vector_length);
        vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)b_scalars, vector_length);
        vbfloat16m1_t a_bf16m1 = __riscv_vreinterpret_v_u16m1_bf16m1(a_u16m1);
        vbfloat16m1_t b_bf16m1 = __riscv_vreinterpret_v_u16m1_bf16m1(b_u16m1);

        // dot += a × b (widened bf16 × bf16 → f32)
        dot_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(dot_f32m2, a_bf16m1, b_bf16m1, vector_length);
        // a_sq += a × a
        a_sq_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(a_sq_f32m2, a_bf16m1, a_bf16m1, vector_length);
        // b_sq += b × b
        b_sq_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(b_sq_f32m2, b_bf16m1, b_bf16m1, vector_length);
    }

    // Single horizontal reduction after the loop, over all `vlmax` lanes.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, vlmax));
    nk_f32_t a_sq = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(a_sq_f32m2, zero_f32m1, vlmax));
    nk_f32_t b_sq = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(b_sq_f32m2, zero_f32m1, vlmax));

    // Normalize: 1 − dot / sqrt(‖a‖² × ‖b‖²), with edge cases handled explicitly:
    // both vectors zero → distance 0; orthogonal (dot == 0) → distance 1.
    if (a_sq == 0.0f && b_sq == 0.0f) { *result = 0.0f; }
    else if (dot == 0.0f) { *result = 1.0f; }
    else {
        // rsqrt of each norm separately avoids one multiply overflowing before the sqrt.
        nk_f32_t unclipped = 1.0f - dot * nk_f32_rsqrt_rvv(a_sq) * nk_f32_rsqrt_rvv(b_sq);
        *result = unclipped > 0.0f ? unclipped : 0.0f;
    }
}
|
|
110
|
+
|
|
111
|
+
#if defined(__cplusplus)
|
|
112
|
+
} // extern "C"
|
|
113
|
+
#endif
|
|
114
|
+
|
|
115
|
+
#if defined(__clang__)
|
|
116
|
+
#pragma clang attribute pop
|
|
117
|
+
#elif defined(__GNUC__)
|
|
118
|
+
#pragma GCC pop_options
|
|
119
|
+
#endif
|
|
120
|
+
|
|
121
|
+
#endif // NK_TARGET_RVVBF16
|
|
122
|
+
#endif // NK_TARGET_RISCV_
|
|
123
|
+
#endif // NK_SPATIAL_RVVBF16_H
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for RISC-V FP16.
|
|
3
|
+
* @file include/numkong/spatial/rvvhalf.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 5, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* Zvfh provides native half-precision (f16) vector operations.
|
|
10
|
+
* Uses widening operations (f16 → f32) for precision accumulation.
|
|
11
|
+
*
|
|
12
|
+
* Requires: RVV 1.0 + Zvfh extension (GCC 14+ or Clang 18+)
|
|
13
|
+
*/
|
|
14
|
+
#ifndef NK_SPATIAL_RVVHALF_H
|
|
15
|
+
#define NK_SPATIAL_RVVHALF_H
|
|
16
|
+
|
|
17
|
+
#if NK_TARGET_RISCV_
|
|
18
|
+
#if NK_TARGET_RVVHALF
|
|
19
|
+
|
|
20
|
+
#include "numkong/types.h"
|
|
21
|
+
#include "numkong/spatial/rvv.h" // `nk_f32_sqrt_rvv`, `nk_f32_rsqrt_rvv`
|
|
22
|
+
|
|
23
|
+
#if defined(__clang__)
|
|
24
|
+
#pragma clang attribute push(__attribute__((target("arch=+v,+zvfh"))), apply_to = function)
|
|
25
|
+
#elif defined(__GNUC__)
|
|
26
|
+
#pragma GCC push_options
|
|
27
|
+
#pragma GCC target("arch=+v,+zvfh")
|
|
28
|
+
#endif
|
|
29
|
+
|
|
30
|
+
#if defined(__cplusplus)
|
|
31
|
+
extern "C" {
|
|
32
|
+
#endif
|
|
33
|
+
|
|
34
|
+
/**
 *  @brief Squared Euclidean (L2²) distance between two f16 vectors, accumulated in f32.
 *
 *  Unlike the bf16 kernel (which must use the a²+b²−2ab identity), Zvfh allows
 *  widening both operands to f32 and subtracting directly, which is the numerically
 *  preferable form: the result is a plain sum of squares and can never be negative.
 *
 *  @param a_scalars      First vector, `count_scalars` f16 values.
 *  @param b_scalars      Second vector, `count_scalars` f16 values.
 *  @param count_scalars  Number of scalars in each vector.
 *  @param result         Output: Σ (aᵢ − bᵢ)² in f32.
 */
NK_PUBLIC void nk_sqeuclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
                                          nk_f32_t *result) {
    // Per-lane accumulator — deferred horizontal reduction.  Zeroed across all
    // `vlmax` lanes so the tail-undisturbed FMA leaves unused lanes at zero.
    nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);

    // Strip-mined loop; `vector_length` is assigned in the body before the increment reads it.
    for (nk_size_t vector_length; count_scalars > 0;
         count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e16m1(count_scalars);
        // Load raw 16-bit payloads and reinterpret as f16 — no numeric conversion.
        vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)a_scalars, vector_length);
        vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)b_scalars, vector_length);
        vfloat16m1_t a_f16m1 = __riscv_vreinterpret_v_u16m1_f16m1(a_u16m1);
        vfloat16m1_t b_f16m1 = __riscv_vreinterpret_v_u16m1_f16m1(b_u16m1);
        // Upcast to f32 before subtraction to avoid catastrophic cancellation in f16
        vfloat32m2_t a_f32m2 = __riscv_vfwcvt_f_f_v_f32m2(a_f16m1, vector_length);
        vfloat32m2_t b_f32m2 = __riscv_vfwcvt_f_f_v_f32m2(b_f16m1, vector_length);
        vfloat32m2_t diff_f32m2 = __riscv_vfsub_vv_f32m2(a_f32m2, b_f32m2, vector_length);
        // Accumulate diff² in f32
        sum_f32m2 = __riscv_vfmacc_vv_f32m2_tu(sum_f32m2, diff_f32m2, diff_f32m2, vector_length);
    }

    // Single horizontal reduction after the loop, over all `vlmax` lanes.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
}
|
|
59
|
+
|
|
60
|
+
NK_PUBLIC void nk_euclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
61
|
+
nk_f32_t *result) {
|
|
62
|
+
nk_sqeuclidean_f16_rvvhalf(a_scalars, b_scalars, count_scalars, result);
|
|
63
|
+
*result = nk_f32_sqrt_rvv(*result);
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
/**
 *  @brief Angular (cosine) distance between two f16 vectors, accumulated in f32.
 *
 *  Computes 1 − dot(a,b) / √(‖a‖²·‖b‖²) using three widening f16 FMA streams
 *  (dot, ‖a‖², ‖b‖²) with a single deferred horizontal reduction each.
 *
 *  @param a_scalars      First vector, `count_scalars` f16 values.
 *  @param b_scalars      Second vector, `count_scalars` f16 values.
 *  @param count_scalars  Number of scalars in each vector.
 *  @param result         Output distance: 0 when both vectors are all-zero,
 *                        1 when the dot product is exactly zero, otherwise the
 *                        cosine distance clipped at zero.
 */
NK_PUBLIC void nk_angular_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
                                      nk_f32_t *result) {
    // Per-lane accumulators — deferred horizontal reduction.  Zeroed across all
    // `vlmax` lanes so the tail-undisturbed (`_tu`) FMAs keep unused lanes at zero.
    nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
    vfloat32m2_t a_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
    vfloat32m2_t b_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);

    // Strip-mined loop; `vector_length` is assigned in the body before the increment reads it.
    for (nk_size_t vector_length; count_scalars > 0;
         count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e16m1(count_scalars);
        // Load raw 16-bit payloads and reinterpret as f16 — no numeric conversion.
        vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)a_scalars, vector_length);
        vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)b_scalars, vector_length);
        vfloat16m1_t a_f16m1 = __riscv_vreinterpret_v_u16m1_f16m1(a_u16m1);
        vfloat16m1_t b_f16m1 = __riscv_vreinterpret_v_u16m1_f16m1(b_u16m1);

        // dot += a × b (widened to f32)
        dot_f32m2 = __riscv_vfwmacc_vv_f32m2_tu(dot_f32m2, a_f16m1, b_f16m1, vector_length);
        // a_sq += a × a
        a_sq_f32m2 = __riscv_vfwmacc_vv_f32m2_tu(a_sq_f32m2, a_f16m1, a_f16m1, vector_length);
        // b_sq += b × b
        b_sq_f32m2 = __riscv_vfwmacc_vv_f32m2_tu(b_sq_f32m2, b_f16m1, b_f16m1, vector_length);
    }

    // Single horizontal reduction after the loop, over all `vlmax` lanes.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, vlmax));
    nk_f32_t a_sq = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(a_sq_f32m2, zero_f32m1, vlmax));
    nk_f32_t b_sq = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(b_sq_f32m2, zero_f32m1, vlmax));

    // Normalize: 1 − dot / sqrt(‖a‖² × ‖b‖²), with edge cases handled explicitly:
    // both vectors zero → distance 0; orthogonal (dot == 0) → distance 1.
    if (a_sq == 0.0f && b_sq == 0.0f) { *result = 0.0f; }
    else if (dot == 0.0f) { *result = 1.0f; }
    else {
        // rsqrt of each norm separately avoids one multiply overflowing before the sqrt.
        nk_f32_t unclipped = 1.0f - dot * nk_f32_rsqrt_rvv(a_sq) * nk_f32_rsqrt_rvv(b_sq);
        *result = unclipped > 0.0f ? unclipped : 0.0f;
    }
}
|
|
104
|
+
|
|
105
|
+
#if defined(__cplusplus)
|
|
106
|
+
} // extern "C"
|
|
107
|
+
#endif
|
|
108
|
+
|
|
109
|
+
#if defined(__clang__)
|
|
110
|
+
#pragma clang attribute pop
|
|
111
|
+
#elif defined(__GNUC__)
|
|
112
|
+
#pragma GCC pop_options
|
|
113
|
+
#endif
|
|
114
|
+
|
|
115
|
+
#endif // NK_TARGET_RVVHALF
|
|
116
|
+
#endif // NK_TARGET_RISCV_
|
|
117
|
+
#endif // NK_SPATIAL_RVVHALF_H
|
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for Sapphire Rapids.
|
|
3
|
+
* @file include/numkong/spatial/sapphire.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* Sapphire Rapids adds native FP16 support via AVX-512 FP16 extension.
|
|
10
|
+
* For e4m3 L2 distance, we can leverage F16 for the subtraction step:
|
|
11
|
+
* - e4m3 differences fit in F16 (max |a−b| = 896 < 65504)
|
|
12
|
+
* - But squared differences overflow F16 (896² = 802816 > 65504)
|
|
13
|
+
* - So: subtract in F16, convert to F32, then square and accumulate
|
|
14
|
+
*
|
|
15
|
+
* For e2m3/e3m2 L2 distance, squared differences fit in FP16:
|
|
16
|
+
* - E2M3: max |a−b| = 15, max (a−b)² = 225 < 65504, flush cadence = 4 (conservative for uniformity)
|
|
17
|
+
* - E3M2: max |a−b| = 56, max (a−b)² = 3136 < 65504, flush cadence = 4
|
|
18
|
+
* So the entire sub+square+accumulate stays in FP16 with periodic F32 flush.
|
|
19
|
+
*
|
|
20
|
+
* @section spatial_sapphire_instructions Relevant Instructions
|
|
21
|
+
*
|
|
22
|
+
* Intrinsic Instruction Sapphire Genoa
|
|
23
|
+
* _mm256_sub_ph VSUBPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
|
|
24
|
+
* _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 5cy @ p01
|
|
25
|
+
* _mm512_fmadd_ps VFMADD (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
|
|
26
|
+
* _mm512_reduce_add_ps (pseudo: VHADDPS chain) ~8cy ~8cy
|
|
27
|
+
* _mm_maskz_loadu_epi8 VMOVDQU8 (XMM {K}, M128) 7cy @ p23 7cy @ p23
|
|
28
|
+
*/
|
|
29
|
+
#ifndef NK_SPATIAL_SAPPHIRE_H
|
|
30
|
+
#define NK_SPATIAL_SAPPHIRE_H
|
|
31
|
+
|
|
32
|
+
#if NK_TARGET_X86_
|
|
33
|
+
#if NK_TARGET_SAPPHIRE
|
|
34
|
+
|
|
35
|
+
#include "numkong/types.h"
|
|
36
|
+
#include "numkong/cast/sapphire.h" // `nk_e4m3x16_to_f16x16_sapphire_`
|
|
37
|
+
#include "numkong/dot/sapphire.h" // `nk_e2m3x32_to_f16x32_sapphire_`, `nk_flush_f16_to_f32_sapphire_`
|
|
38
|
+
#include "numkong/spatial/haswell.h" // `nk_angular_normalize_f32_haswell_`, `nk_f32_sqrt_haswell`
|
|
39
|
+
|
|
40
|
+
#if defined(__cplusplus)
|
|
41
|
+
extern "C" {
|
|
42
|
+
#endif
|
|
43
|
+
|
|
44
|
+
#if defined(__clang__)
|
|
45
|
+
#pragma clang attribute push( \
|
|
46
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,f16c,fma,bmi,bmi2"))), \
|
|
47
|
+
apply_to = function)
|
|
48
|
+
#elif defined(__GNUC__)
|
|
49
|
+
#pragma GCC push_options
|
|
50
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "f16c", "fma", "bmi", "bmi2")
|
|
51
|
+
#endif
|
|
52
|
+
|
|
53
|
+
/**
 *  @brief Squared Euclidean (L2²) distance between two e4m3 vectors on Sapphire Rapids.
 *
 *  Strategy (see the file header): e4m3 differences fit in F16 (max |a−b| = 896 < 65504),
 *  but their squares do not (896² = 802816), so the pipeline is:
 *  masked-load 16 bytes → e4m3 → f16 → subtract in f16 → widen to f32 → FMA-accumulate.
 *
 *  @param a_scalars      First vector, `count_scalars` e4m3 bytes.
 *  @param b_scalars      Second vector, `count_scalars` e4m3 bytes.
 *  @param count_scalars  Number of scalars in each vector.
 *  @param result         Output: Σ (aᵢ − bᵢ)² in f32.
 */
NK_PUBLIC void nk_sqeuclidean_e4m3_sapphire(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars,
                                            nk_size_t count_scalars, nk_f32_t *result) {
    __m512 sum_f32x16 = _mm512_setzero_ps();

    // Masked loads handle both the full 16-element chunks and the final partial chunk,
    // so a single loop covers the whole input without a separate tail path.
    while (count_scalars > 0) {
        nk_size_t const n = count_scalars < 16 ? count_scalars : 16;
        // BZHI clears all mask bits above `n`; for n == 16 the mask is all-ones.
        __mmask16 const mask = (__mmask16)_bzhi_u32(0xFFFF, n);
        __m128i a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
        __m128i b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b_scalars);

        // Convert e4m3 → f16 (exact: every e4m3 value is representable in f16)
        __m256h a_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(a_e4m3x16);
        __m256h b_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(b_e4m3x16);

        // Subtract in F16 − differences fit (max 896 < 65504)
        __m256h diff_f16x16 = _mm256_sub_ph(a_f16x16, b_f16x16);

        // Convert to F32 before squaring (896² = 802816 overflows F16!)
        __m512 diff_f32x16 = _mm512_cvtph_ps(_mm256_castph_si256(diff_f16x16));

        // Square and accumulate in F32; masked-out lanes are zero and contribute nothing.
        sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
        a_scalars += n, b_scalars += n, count_scalars -= n;
    }

    *result = _mm512_reduce_add_ps(sum_f32x16);
}
|
|
80
|
+
|
|
81
|
+
NK_PUBLIC void nk_euclidean_e4m3_sapphire(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars,
|
|
82
|
+
nk_size_t count_scalars, nk_f32_t *result) {
|
|
83
|
+
nk_sqeuclidean_e4m3_sapphire(a_scalars, b_scalars, count_scalars, result);
|
|
84
|
+
*result = nk_f32_sqrt_haswell(*result);
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
NK_PUBLIC void nk_sqeuclidean_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
|
|
88
|
+
nk_size_t count_scalars, nk_f32_t *result) {
|
|
89
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
90
|
+
|
|
91
|
+
// Main loop: 4-way unrolled, 128 elements per flush
|
|
92
|
+
while (count_scalars >= 128) {
|
|
93
|
+
__m512h acc_f16x32 = _mm512_setzero_ph();
|
|
94
|
+
__m512h a_f16x32, b_f16x32, diff_f16x32;
|
|
95
|
+
// Iteration 1
|
|
96
|
+
a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
|
|
97
|
+
b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
|
|
98
|
+
diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
|
|
99
|
+
acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
|
|
100
|
+
// Iteration 2
|
|
101
|
+
a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
|
|
102
|
+
b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
|
|
103
|
+
diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
|
|
104
|
+
acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
|
|
105
|
+
// Iteration 3
|
|
106
|
+
a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
|
|
107
|
+
b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
|
|
108
|
+
diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
|
|
109
|
+
acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
|
|
110
|
+
// Iteration 4
|
|
111
|
+
a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
|
|
112
|
+
b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
|
|
113
|
+
diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
|
|
114
|
+
acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
|
|
115
|
+
// Flush to F32
|
|
116
|
+
sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
|
|
117
|
+
a_scalars += 128, b_scalars += 128, count_scalars -= 128;
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
// Tail: remaining 0–127 elements, 32 at a time via masked loads
|
|
121
|
+
__m512h acc_f16x32 = _mm512_setzero_ph();
|
|
122
|
+
while (count_scalars > 0) {
|
|
123
|
+
nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
|
|
124
|
+
__mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
|
|
125
|
+
__m512h a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
|
|
126
|
+
__m512h b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
|
|
127
|
+
__m512h diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
|
|
128
|
+
acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
|
|
129
|
+
a_scalars += n, b_scalars += n, count_scalars -= n;
|
|
130
|
+
}
|
|
131
|
+
sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
|
|
132
|
+
|
|
133
|
+
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
NK_PUBLIC void nk_sqeuclidean_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
|
|
137
|
+
nk_size_t count_scalars, nk_f32_t *result) {
|
|
138
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
139
|
+
|
|
140
|
+
// Main loop: 4-way unrolled, 128 elements per flush
|
|
141
|
+
while (count_scalars >= 128) {
|
|
142
|
+
__m512h acc_f16x32 = _mm512_setzero_ph();
|
|
143
|
+
__m512h a_f16x32, b_f16x32, diff_f16x32;
|
|
144
|
+
// Iteration 1
|
|
145
|
+
a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
|
|
146
|
+
b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
|
|
147
|
+
diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
|
|
148
|
+
acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
|
|
149
|
+
// Iteration 2
|
|
150
|
+
a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
|
|
151
|
+
b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
|
|
152
|
+
diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
|
|
153
|
+
acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
|
|
154
|
+
// Iteration 3
|
|
155
|
+
a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
|
|
156
|
+
b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
|
|
157
|
+
diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
|
|
158
|
+
acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
|
|
159
|
+
// Iteration 4
|
|
160
|
+
a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
|
|
161
|
+
b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
|
|
162
|
+
diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
|
|
163
|
+
acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
|
|
164
|
+
// Flush to F32
|
|
165
|
+
sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
|
|
166
|
+
a_scalars += 128, b_scalars += 128, count_scalars -= 128;
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
// Tail: remaining 0–127 elements, 32 at a time via masked loads
|
|
170
|
+
__m512h acc_f16x32 = _mm512_setzero_ph();
|
|
171
|
+
while (count_scalars > 0) {
|
|
172
|
+
nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
|
|
173
|
+
__mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
|
|
174
|
+
__m512h a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
|
|
175
|
+
__m512h b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
|
|
176
|
+
__m512h diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
|
|
177
|
+
acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
|
|
178
|
+
a_scalars += n, b_scalars += n, count_scalars -= n;
|
|
179
|
+
}
|
|
180
|
+
sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
|
|
181
|
+
|
|
182
|
+
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
NK_PUBLIC void nk_euclidean_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
|
|
186
|
+
nk_size_t count_scalars, nk_f32_t *result) {
|
|
187
|
+
nk_sqeuclidean_e2m3_sapphire(a_scalars, b_scalars, count_scalars, result);
|
|
188
|
+
*result = nk_f32_sqrt_haswell(*result);
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
NK_PUBLIC void nk_euclidean_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
|
|
192
|
+
nk_size_t count_scalars, nk_f32_t *result) {
|
|
193
|
+
nk_sqeuclidean_e3m2_sapphire(a_scalars, b_scalars, count_scalars, result);
|
|
194
|
+
*result = nk_f32_sqrt_haswell(*result);
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
NK_PUBLIC void nk_angular_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
198
|
+
nk_f32_t *result) {
|
|
199
|
+
__m512 sum_dot_f32x16 = _mm512_setzero_ps();
|
|
200
|
+
__m512 sum_a_f32x16 = _mm512_setzero_ps();
|
|
201
|
+
__m512 sum_b_f32x16 = _mm512_setzero_ps();
|
|
202
|
+
|
|
203
|
+
// Main loop: 4-way unrolled, 128 elements per flush
|
|
204
|
+
while (count_scalars >= 128) {
|
|
205
|
+
__m512h dot_acc = _mm512_setzero_ph();
|
|
206
|
+
__m512h a_norm_acc = _mm512_setzero_ph();
|
|
207
|
+
__m512h b_norm_acc = _mm512_setzero_ph();
|
|
208
|
+
__m512h a_f16x32, b_f16x32;
|
|
209
|
+
// Iteration 1
|
|
210
|
+
a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
|
|
211
|
+
b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
|
|
212
|
+
dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
|
|
213
|
+
a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
|
|
214
|
+
b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
|
|
215
|
+
// Iteration 2
|
|
216
|
+
a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
|
|
217
|
+
b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
|
|
218
|
+
dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
|
|
219
|
+
a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
|
|
220
|
+
b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
|
|
221
|
+
// Iteration 3
|
|
222
|
+
a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
|
|
223
|
+
b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
|
|
224
|
+
dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
|
|
225
|
+
a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
|
|
226
|
+
b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
|
|
227
|
+
// Iteration 4
|
|
228
|
+
a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
|
|
229
|
+
b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
|
|
230
|
+
dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
|
|
231
|
+
a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
|
|
232
|
+
b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
|
|
233
|
+
// Flush to F32
|
|
234
|
+
sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
|
|
235
|
+
sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
|
|
236
|
+
sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
|
|
237
|
+
a_scalars += 128, b_scalars += 128, count_scalars -= 128;
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
// Tail: remaining 0–127 elements, 32 at a time via masked loads
|
|
241
|
+
__m512h dot_acc = _mm512_setzero_ph();
|
|
242
|
+
__m512h a_norm_acc = _mm512_setzero_ph();
|
|
243
|
+
__m512h b_norm_acc = _mm512_setzero_ph();
|
|
244
|
+
while (count_scalars > 0) {
|
|
245
|
+
nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
|
|
246
|
+
__mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
|
|
247
|
+
__m512h a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
|
|
248
|
+
__m512h b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
|
|
249
|
+
dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
|
|
250
|
+
a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
|
|
251
|
+
b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
|
|
252
|
+
a_scalars += n, b_scalars += n, count_scalars -= n;
|
|
253
|
+
}
|
|
254
|
+
sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
|
|
255
|
+
sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
|
|
256
|
+
sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
|
|
257
|
+
|
|
258
|
+
nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(sum_dot_f32x16);
|
|
259
|
+
nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_a_f32x16);
|
|
260
|
+
nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_b_f32x16);
|
|
261
|
+
*result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
NK_PUBLIC void nk_angular_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
|
|
265
|
+
nk_f32_t *result) {
|
|
266
|
+
__m512 sum_dot_f32x16 = _mm512_setzero_ps();
|
|
267
|
+
__m512 sum_a_f32x16 = _mm512_setzero_ps();
|
|
268
|
+
__m512 sum_b_f32x16 = _mm512_setzero_ps();
|
|
269
|
+
|
|
270
|
+
// Main loop: 4-way unrolled, 128 elements per flush
|
|
271
|
+
while (count_scalars >= 128) {
|
|
272
|
+
__m512h dot_acc = _mm512_setzero_ph();
|
|
273
|
+
__m512h a_norm_acc = _mm512_setzero_ph();
|
|
274
|
+
__m512h b_norm_acc = _mm512_setzero_ph();
|
|
275
|
+
__m512h a_f16x32, b_f16x32;
|
|
276
|
+
// Iteration 1
|
|
277
|
+
a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
|
|
278
|
+
b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
|
|
279
|
+
dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
|
|
280
|
+
a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
|
|
281
|
+
b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
|
|
282
|
+
// Iteration 2
|
|
283
|
+
a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
|
|
284
|
+
b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
|
|
285
|
+
dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
|
|
286
|
+
a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
|
|
287
|
+
b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
|
|
288
|
+
// Iteration 3
|
|
289
|
+
a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
|
|
290
|
+
b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
|
|
291
|
+
dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
|
|
292
|
+
a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
|
|
293
|
+
b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
|
|
294
|
+
// Iteration 4
|
|
295
|
+
a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
|
|
296
|
+
b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
|
|
297
|
+
dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
|
|
298
|
+
a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
|
|
299
|
+
b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
|
|
300
|
+
// Flush to F32
|
|
301
|
+
sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
|
|
302
|
+
sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
|
|
303
|
+
sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
|
|
304
|
+
a_scalars += 128, b_scalars += 128, count_scalars -= 128;
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
// Tail: remaining 0–127 elements, 32 at a time via masked loads
|
|
308
|
+
__m512h dot_acc = _mm512_setzero_ph();
|
|
309
|
+
__m512h a_norm_acc = _mm512_setzero_ph();
|
|
310
|
+
__m512h b_norm_acc = _mm512_setzero_ph();
|
|
311
|
+
while (count_scalars > 0) {
|
|
312
|
+
nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
|
|
313
|
+
__mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
|
|
314
|
+
__m512h a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
|
|
315
|
+
__m512h b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
|
|
316
|
+
dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
|
|
317
|
+
a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
|
|
318
|
+
b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
|
|
319
|
+
a_scalars += n, b_scalars += n, count_scalars -= n;
|
|
320
|
+
}
|
|
321
|
+
sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
|
|
322
|
+
sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
|
|
323
|
+
sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
|
|
324
|
+
|
|
325
|
+
nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(sum_dot_f32x16);
|
|
326
|
+
nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_a_f32x16);
|
|
327
|
+
nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_b_f32x16);
|
|
328
|
+
*result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
#if defined(__clang__)
|
|
332
|
+
#pragma clang attribute pop
|
|
333
|
+
#elif defined(__GNUC__)
|
|
334
|
+
#pragma GCC pop_options
|
|
335
|
+
#endif
|
|
336
|
+
|
|
337
|
+
#if defined(__cplusplus)
|
|
338
|
+
} // extern "C"
|
|
339
|
+
#endif
|
|
340
|
+
|
|
341
|
+
#endif // NK_TARGET_SAPPHIRE
|
|
342
|
+
#endif // NK_TARGET_X86_
|
|
343
|
+
#endif // NK_SPATIAL_SAPPHIRE_H
|