numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Curved Space Distances for RISC-V.
|
|
3
|
+
* @file include/numkong/curved/rvv.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 6, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/curved.h
|
|
8
|
+
*
|
|
9
|
+
* Implements bilinear forms and Mahalanobis distance using RVV 1.0:
|
|
10
|
+
* - f32 inputs use f32 SIMD accumulation with vfredusum ordered reduction
|
|
11
|
+
* - f64 inputs use f64 SIMD accumulation with vfredusum ordered reduction
|
|
12
|
+
* - f16/bf16 inputs are converted to f32 via cast helpers, then accumulated in f32
|
|
13
|
+
* - Complex bilinear forms delegate to serial implementations
|
|
14
|
+
*/
|
|
15
|
+
#ifndef NK_CURVED_RVV_H
|
|
16
|
+
#define NK_CURVED_RVV_H
|
|
17
|
+
|
|
18
|
+
#if NK_TARGET_RISCV_
|
|
19
|
+
#if NK_TARGET_RVV
|
|
20
|
+
|
|
21
|
+
#include "numkong/types.h"
|
|
22
|
+
#include "numkong/curved/serial.h"
|
|
23
|
+
#include "numkong/cast/rvv.h"
|
|
24
|
+
#include "numkong/spatial/rvv.h" // `nk_f64_sqrt_rvv`
|
|
25
|
+
|
|
26
|
+
#if defined(__clang__)
|
|
27
|
+
#pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
|
|
28
|
+
#elif defined(__GNUC__)
|
|
29
|
+
#pragma GCC push_options
|
|
30
|
+
#pragma GCC target("arch=+v")
|
|
31
|
+
#endif
|
|
32
|
+
|
|
33
|
+
#if defined(__cplusplus)
|
|
34
|
+
extern "C" {
|
|
35
|
+
#endif
|
|
36
|
+
|
|
37
|
+
NK_PUBLIC void nk_bilinear_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
38
|
+
nk_f64_t *result) {
|
|
39
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
40
|
+
nk_f64_t outer_sum = 0;
|
|
41
|
+
for (nk_size_t i = 0; i < n; ++i) {
|
|
42
|
+
vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
43
|
+
nk_f32_t const *c_row = c + i * n;
|
|
44
|
+
nk_size_t remaining = n;
|
|
45
|
+
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
46
|
+
vector_length = __riscv_vsetvl_e32m2(remaining);
|
|
47
|
+
vfloat32m2_t c_f32m2 = __riscv_vle32_v_f32m2(c_row, vector_length);
|
|
48
|
+
vfloat32m2_t b_f32m2 = __riscv_vle32_v_f32m2(b + (n - remaining), vector_length);
|
|
49
|
+
inner_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(inner_f64m4, c_f32m2, b_f32m2, vector_length);
|
|
50
|
+
}
|
|
51
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
52
|
+
nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
|
|
53
|
+
__riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
|
|
54
|
+
outer_sum += (nk_f64_t)a[i] * inner_val;
|
|
55
|
+
}
|
|
56
|
+
*result = outer_sum;
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
NK_PUBLIC void nk_bilinear_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                                   nk_f64_t *result) {
    // Computes aᵀ × C × b for f64 inputs. Because there is no wider type to
    // upcast into, both accumulation levels are compensated: the inner row
    // dot product uses lane-wise Kahan summation, and the outer sum uses a
    // scalar Neumaier (branching Kahan) update. Statement order below is
    // precision-critical — do not reorder the compensated arithmetic.
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
    nk_f64_t outer_compensation = 0;
    for (nk_size_t i = 0; i < n; ++i) {
        // Per-lane running sum and per-lane compensation, zeroed to VLMAX so
        // the final full-width reduction only sees written or zero lanes.
        vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
        vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
        nk_f64_t const *c_row = c + i * n;
        nk_size_t remaining = n;
        for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
            vector_length = __riscv_vsetvl_e64m4(remaining);
            vfloat64m4_t vc_f64m4 = __riscv_vle64_v_f64m4(c_row, vector_length);
            vfloat64m4_t vb_f64m4 = __riscv_vle64_v_f64m4(b + (n - remaining), vector_length);
            vfloat64m4_t product_f64m4 = __riscv_vfmul_vv_f64m4(vc_f64m4, vb_f64m4, vector_length);
            // Kahan step 1: y = product − compensation.
            vfloat64m4_t corrected_term_f64m4 = __riscv_vfsub_vv_f64m4(product_f64m4, compensation_f64m4,
                                                                       vector_length);
            // Kahan step 2: t = sum + y (tail-undisturbed so lanes past
            // `vector_length` keep their zeros for the later reduction).
            vfloat64m4_t running_sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(inner_f64m4, inner_f64m4, corrected_term_f64m4,
                                                                       vector_length);
            // Kahan step 3: compensation = (t − sum) − y, recovering the
            // low-order bits lost in the addition above.
            compensation_f64m4 = __riscv_vfsub_vv_f64m4_tu(
                compensation_f64m4, __riscv_vfsub_vv_f64m4(running_sum_f64m4, inner_f64m4, vector_length),
                corrected_term_f64m4, vector_length);
            inner_f64m4 = running_sum_f64m4;
        }
        vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
        nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
            __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
        nk_f64_t product_outer = a[i] * inner_val;
        // Scalar Neumaier update of the outer sum: the branch picks which
        // operand's low-order bits were lost and adds them to the running
        // compensation term instead of discarding them.
        nk_f64_t old_sum = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1);
        nk_f64_t new_sum = old_sum + product_outer;
        if (nk_f64_abs_(old_sum) >= nk_f64_abs_(product_outer))
            outer_compensation += (old_sum - new_sum) + product_outer;
        else outer_compensation += (product_outer - new_sum) + old_sum;
        sum_f64m1 = __riscv_vfmv_v_f_f64m1(new_sum, 1);
    }
    // Fold the accumulated compensation back in exactly once, at the end.
    *result = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1) + outer_compensation;
}
|
|
96
|
+
|
|
97
|
+
NK_PUBLIC void nk_bilinear_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
98
|
+
nk_f32_t *result) {
|
|
99
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
100
|
+
vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
101
|
+
for (nk_size_t i = 0; i < n; ++i) {
|
|
102
|
+
// Convert a[i] from f16 to f32
|
|
103
|
+
nk_f32_t a_i;
|
|
104
|
+
nk_f16_to_f32_serial(a + i, &a_i);
|
|
105
|
+
|
|
106
|
+
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
|
|
107
|
+
nk_f16_t const *c_row = c + i * n;
|
|
108
|
+
nk_size_t remaining = n;
|
|
109
|
+
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
110
|
+
vector_length = __riscv_vsetvl_e16m1(remaining);
|
|
111
|
+
// Load f16 as u16 bits and convert to f32
|
|
112
|
+
vuint16m1_t vc_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)c_row, vector_length);
|
|
113
|
+
vuint16m1_t vb_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(b + (n - remaining)), vector_length);
|
|
114
|
+
vfloat32m2_t vc_f32m2 = nk_f16m1_to_f32m2_rvv_(vc_u16m1, vector_length);
|
|
115
|
+
vfloat32m2_t vb_f32m2 = nk_f16m1_to_f32m2_rvv_(vb_u16m1, vector_length);
|
|
116
|
+
inner_f32m2 = __riscv_vfmacc_vv_f32m2_tu(inner_f32m2, vc_f32m2, vb_f32m2, vector_length);
|
|
117
|
+
}
|
|
118
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
119
|
+
nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
|
|
120
|
+
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
|
|
121
|
+
sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + a_i * inner_val, 1);
|
|
122
|
+
}
|
|
123
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
NK_PUBLIC void nk_bilinear_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
127
|
+
nk_f32_t *result) {
|
|
128
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
129
|
+
vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
130
|
+
for (nk_size_t i = 0; i < n; ++i) {
|
|
131
|
+
// Convert a[i] from bf16 to f32
|
|
132
|
+
nk_f32_t a_i;
|
|
133
|
+
nk_bf16_to_f32_serial(a + i, &a_i);
|
|
134
|
+
|
|
135
|
+
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
|
|
136
|
+
nk_bf16_t const *c_row = c + i * n;
|
|
137
|
+
nk_size_t remaining = n;
|
|
138
|
+
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
139
|
+
vector_length = __riscv_vsetvl_e16m1(remaining);
|
|
140
|
+
// Load bf16 as u16 bits and convert to f32
|
|
141
|
+
vuint16m1_t vc_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)c_row, vector_length);
|
|
142
|
+
vuint16m1_t vb_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(b + (n - remaining)), vector_length);
|
|
143
|
+
vfloat32m2_t vc_f32m2 = nk_bf16m1_to_f32m2_rvv_(vc_u16m1, vector_length);
|
|
144
|
+
vfloat32m2_t vb_f32m2 = nk_bf16m1_to_f32m2_rvv_(vb_u16m1, vector_length);
|
|
145
|
+
inner_f32m2 = __riscv_vfmacc_vv_f32m2_tu(inner_f32m2, vc_f32m2, vb_f32m2, vector_length);
|
|
146
|
+
}
|
|
147
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
148
|
+
nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
|
|
149
|
+
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
|
|
150
|
+
sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + a_i * inner_val, 1);
|
|
151
|
+
}
|
|
152
|
+
*result = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
NK_PUBLIC void nk_mahalanobis_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
156
|
+
nk_f64_t *result) {
|
|
157
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
158
|
+
nk_f64_t outer_sum = 0;
|
|
159
|
+
for (nk_size_t i = 0; i < n; ++i) {
|
|
160
|
+
nk_f64_t diff_i = (nk_f64_t)a[i] - (nk_f64_t)b[i];
|
|
161
|
+
vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
|
|
162
|
+
nk_f32_t const *c_row = c + i * n;
|
|
163
|
+
nk_size_t remaining = n;
|
|
164
|
+
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
165
|
+
vector_length = __riscv_vsetvl_e32m2(remaining);
|
|
166
|
+
nk_size_t j = n - remaining;
|
|
167
|
+
vfloat32m2_t c_f32m2 = __riscv_vle32_v_f32m2(c_row, vector_length);
|
|
168
|
+
vfloat32m2_t a_f32m2 = __riscv_vle32_v_f32m2(a + j, vector_length);
|
|
169
|
+
vfloat32m2_t b_f32m2 = __riscv_vle32_v_f32m2(b + j, vector_length);
|
|
170
|
+
vfloat64m4_t diff_f64m4 = __riscv_vfwsub_vv_f64m4(a_f32m2, b_f32m2, vector_length);
|
|
171
|
+
vfloat64m4_t c_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(c_f32m2, vector_length);
|
|
172
|
+
inner_f64m4 = __riscv_vfmacc_vv_f64m4_tu(inner_f64m4, c_f64m4, diff_f64m4, vector_length);
|
|
173
|
+
}
|
|
174
|
+
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
175
|
+
nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
|
|
176
|
+
__riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
|
|
177
|
+
outer_sum += diff_i * inner_val;
|
|
178
|
+
}
|
|
179
|
+
*result = nk_f64_sqrt_rvv(outer_sum > 0 ? outer_sum : 0);
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
NK_PUBLIC void nk_mahalanobis_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                                      nk_f64_t *result) {
    // Mahalanobis distance √((a−b)ᵀ × C × (a−b)) for f64 inputs. With no
    // wider type available, both accumulation levels are compensated: the
    // inner dot product uses lane-wise Kahan summation, the outer sum a
    // scalar Neumaier update. Statement order is precision-critical.
    nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
    vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
    nk_f64_t outer_compensation = 0;
    for (nk_size_t i = 0; i < n; ++i) {
        nk_f64_t diff_i = a[i] - b[i];
        // Per-lane running sum and compensation, zeroed to VLMAX so the
        // final full-width reduction only sees written or zero lanes.
        vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
        vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
        nk_f64_t const *c_row = c + i * n;
        nk_size_t remaining = n;
        for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
            vector_length = __riscv_vsetvl_e64m4(remaining);
            nk_size_t j = n - remaining;
            vfloat64m4_t vc_f64m4 = __riscv_vle64_v_f64m4(c_row, vector_length);
            vfloat64m4_t va_f64m4 = __riscv_vle64_v_f64m4(a + j, vector_length);
            vfloat64m4_t vb_f64m4 = __riscv_vle64_v_f64m4(b + j, vector_length);
            vfloat64m4_t diff_j_f64m4 = __riscv_vfsub_vv_f64m4(va_f64m4, vb_f64m4, vector_length);
            vfloat64m4_t product_f64m4 = __riscv_vfmul_vv_f64m4(vc_f64m4, diff_j_f64m4, vector_length);
            // Kahan step 1: y = product − compensation.
            vfloat64m4_t corrected_term_f64m4 = __riscv_vfsub_vv_f64m4(product_f64m4, compensation_f64m4,
                                                                       vector_length);
            // Kahan step 2: t = sum + y (tail-undisturbed so lanes past
            // `vector_length` keep their zeros for the later reduction).
            vfloat64m4_t running_sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(inner_f64m4, inner_f64m4, corrected_term_f64m4,
                                                                       vector_length);
            // Kahan step 3: compensation = (t − sum) − y, capturing the
            // low-order bits lost in the addition above.
            compensation_f64m4 = __riscv_vfsub_vv_f64m4_tu(
                compensation_f64m4, __riscv_vfsub_vv_f64m4(running_sum_f64m4, inner_f64m4, vector_length),
                corrected_term_f64m4, vector_length);
            inner_f64m4 = running_sum_f64m4;
        }
        vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
        nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
            __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
        nk_f64_t product_outer = diff_i * inner_val;
        // Scalar Neumaier update of the outer sum: the branch selects which
        // operand's low-order bits were lost and preserves them in the
        // running compensation term.
        nk_f64_t old_sum = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1);
        nk_f64_t new_sum = old_sum + product_outer;
        if (nk_f64_abs_(old_sum) >= nk_f64_abs_(product_outer))
            outer_compensation += (old_sum - new_sum) + product_outer;
        else outer_compensation += (product_outer - new_sum) + old_sum;
        sum_f64m1 = __riscv_vfmv_v_f_f64m1(new_sum, 1);
    }
    nk_f64_t quadratic = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1) + outer_compensation;
    // Clamp small negative round-off before the square root.
    *result = nk_f64_sqrt_rvv(quadratic > 0 ? quadratic : 0);
}
|
|
224
|
+
|
|
225
|
+
NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
226
|
+
nk_f32_t *result) {
|
|
227
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
228
|
+
vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
229
|
+
for (nk_size_t i = 0; i < n; ++i) {
|
|
230
|
+
nk_f32_t a_i, b_i;
|
|
231
|
+
nk_f16_to_f32_serial(a + i, &a_i);
|
|
232
|
+
nk_f16_to_f32_serial(b + i, &b_i);
|
|
233
|
+
nk_f32_t diff_i = a_i - b_i;
|
|
234
|
+
|
|
235
|
+
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
|
|
236
|
+
nk_f16_t const *c_row = c + i * n;
|
|
237
|
+
nk_size_t remaining = n;
|
|
238
|
+
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
239
|
+
vector_length = __riscv_vsetvl_e16m1(remaining);
|
|
240
|
+
nk_size_t j = n - remaining;
|
|
241
|
+
vuint16m1_t vc_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)c_row, vector_length);
|
|
242
|
+
vuint16m1_t va_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(a + j), vector_length);
|
|
243
|
+
vuint16m1_t vb_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(b + j), vector_length);
|
|
244
|
+
vfloat32m2_t vc_f32m2 = nk_f16m1_to_f32m2_rvv_(vc_u16m1, vector_length);
|
|
245
|
+
vfloat32m2_t va_f32m2 = nk_f16m1_to_f32m2_rvv_(va_u16m1, vector_length);
|
|
246
|
+
vfloat32m2_t vb_f32m2 = nk_f16m1_to_f32m2_rvv_(vb_u16m1, vector_length);
|
|
247
|
+
vfloat32m2_t diff_j_f32m2 = __riscv_vfsub_vv_f32m2(va_f32m2, vb_f32m2, vector_length);
|
|
248
|
+
inner_f32m2 = __riscv_vfmacc_vv_f32m2_tu(inner_f32m2, vc_f32m2, diff_j_f32m2, vector_length);
|
|
249
|
+
}
|
|
250
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
251
|
+
nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
|
|
252
|
+
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
|
|
253
|
+
sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + diff_i * inner_val, 1);
|
|
254
|
+
}
|
|
255
|
+
nk_f32_t quadratic_f16 = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
|
|
256
|
+
*result = nk_f32_sqrt_rvv(quadratic_f16 > 0 ? quadratic_f16 : 0);
|
|
257
|
+
}
|
|
258
|
+
|
|
259
|
+
NK_PUBLIC void nk_mahalanobis_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
260
|
+
nk_f32_t *result) {
|
|
261
|
+
nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
|
|
262
|
+
vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
263
|
+
for (nk_size_t i = 0; i < n; ++i) {
|
|
264
|
+
nk_f32_t a_i, b_i;
|
|
265
|
+
nk_bf16_to_f32_serial(a + i, &a_i);
|
|
266
|
+
nk_bf16_to_f32_serial(b + i, &b_i);
|
|
267
|
+
nk_f32_t diff_i = a_i - b_i;
|
|
268
|
+
|
|
269
|
+
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
|
|
270
|
+
nk_bf16_t const *c_row = c + i * n;
|
|
271
|
+
nk_size_t remaining = n;
|
|
272
|
+
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
273
|
+
vector_length = __riscv_vsetvl_e16m1(remaining);
|
|
274
|
+
nk_size_t j = n - remaining;
|
|
275
|
+
vuint16m1_t vc_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)c_row, vector_length);
|
|
276
|
+
vuint16m1_t va_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(a + j), vector_length);
|
|
277
|
+
vuint16m1_t vb_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(b + j), vector_length);
|
|
278
|
+
vfloat32m2_t vc_f32m2 = nk_bf16m1_to_f32m2_rvv_(vc_u16m1, vector_length);
|
|
279
|
+
vfloat32m2_t va_f32m2 = nk_bf16m1_to_f32m2_rvv_(va_u16m1, vector_length);
|
|
280
|
+
vfloat32m2_t vb_f32m2 = nk_bf16m1_to_f32m2_rvv_(vb_u16m1, vector_length);
|
|
281
|
+
vfloat32m2_t diff_j_f32m2 = __riscv_vfsub_vv_f32m2(va_f32m2, vb_f32m2, vector_length);
|
|
282
|
+
inner_f32m2 = __riscv_vfmacc_vv_f32m2_tu(inner_f32m2, vc_f32m2, diff_j_f32m2, vector_length);
|
|
283
|
+
}
|
|
284
|
+
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
285
|
+
nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
|
|
286
|
+
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
|
|
287
|
+
sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + diff_i * inner_val, 1);
|
|
288
|
+
}
|
|
289
|
+
nk_f32_t quadratic_bf16 = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
|
|
290
|
+
*result = nk_f32_sqrt_rvv(quadratic_bf16 > 0 ? quadratic_bf16 : 0);
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
#if defined(__cplusplus)
|
|
294
|
+
} // extern "C"
|
|
295
|
+
#endif
|
|
296
|
+
|
|
297
|
+
#if defined(__clang__)
|
|
298
|
+
#pragma clang attribute pop
|
|
299
|
+
#elif defined(__GNUC__)
|
|
300
|
+
#pragma GCC pop_options
|
|
301
|
+
#endif
|
|
302
|
+
|
|
303
|
+
#endif // NK_TARGET_RVV
|
|
304
|
+
#endif // NK_TARGET_RISCV_
|
|
305
|
+
#endif // NK_CURVED_RVV_H
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SWAR-accelerated Curved Space Similarity for SIMD-free CPUs.
|
|
3
|
+
* @file include/numkong/curved/serial.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 14, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/curved.h
|
|
8
|
+
*
|
|
9
|
+
* Implements bilinear forms and Mahalanobis distance with precision-appropriate strategies:
|
|
10
|
+
* - f64 inputs use Dot2 (Ogita-Rump-Oishi 2005) for error-free transformations
|
|
11
|
+
* - f32/f16/bf16 inputs upcast to wider accumulators (f64/f32), providing sufficient
|
|
12
|
+
* precision headroom without compensation overhead
|
|
13
|
+
*
|
|
14
|
+
* Bilinear form: aᵀ × C × b = Σᵢ aᵢ × (Σⱼ cᵢⱼ × bⱼ)
|
|
15
|
+
*
|
|
16
|
+
* The nested loop structure has two accumulation levels:
|
|
17
|
+
* - Inner: Σⱼ cᵢⱼ × bⱼ (O(n) terms per row)
|
|
18
|
+
* - Outer: Σᵢ aᵢ × inner_result (O(n) terms total)
|
|
19
|
+
*
|
|
20
|
+
* For f64→f64 (no upcast headroom): Dot2 uses TwoProd and TwoSum error-free
|
|
21
|
+
* transformations at both levels, capturing rounding errors in compensation terms.
|
|
22
|
+
*
|
|
23
|
+
* For upcasted types (f32→f64, f16→f32, bf16→f32): the wider accumulator provides
|
|
24
|
+
* enough extra mantissa bits that simple accumulation suffices.
|
|
25
|
+
*
|
|
26
|
+
* @see Ogita, T., Rump, S.M., Oishi, S. (2005). "Accurate Sum and Dot Product"
|
|
27
|
+
*/
|
|
28
|
+
#ifndef NK_CURVED_SERIAL_H
|
|
29
|
+
#define NK_CURVED_SERIAL_H
|
|
30
|
+
|
|
31
|
+
#include "numkong/types.h"
|
|
32
|
+
#include "numkong/spatial/serial.h" // `nk_f64_sqrt_serial`
|
|
33
|
+
|
|
34
|
+
#if defined(__cplusplus)
|
|
35
|
+
extern "C" {
|
|
36
|
+
#endif
|
|
37
|
+
|
|
38
|
+
/**
 *  @brief Macro for bilinear form aᵀ × C × b with simple accumulation.
 *
 *  Suitable for upcasted types where the wider accumulator provides sufficient
 *  precision headroom (f32→f64, f16→f32, bf16→f32). `load_and_convert` reads one
 *  element of the input type and widens it into the accumulator type.
 */
#define nk_define_bilinear_(input_type, accumulator_type, output_type, load_and_convert)                           \
    NK_PUBLIC void nk_bilinear_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b,   \
                                                     nk_##input_type##_t const *c, nk_size_t n,                    \
                                                     nk_##output_type##_t *result) {                               \
        nk_##accumulator_type##_t total = 0;                                                                       \
        nk_##accumulator_type##_t a_i, b_j, c_ij;                                                                  \
        for (nk_size_t i = 0; i < n; ++i) {                                                                        \
            /* Inner product of the i-th row of C with b, in accumulator precision. */                             \
            nk_##accumulator_type##_t row_sum = 0;                                                                 \
            load_and_convert(a + i, &a_i);                                                                         \
            for (nk_size_t j = 0; j < n; ++j) {                                                                    \
                load_and_convert(b + j, &b_j);                                                                     \
                load_and_convert(c + i * n + j, &c_ij);                                                            \
                row_sum += c_ij * b_j;                                                                             \
            }                                                                                                      \
            total += a_i * row_sum;                                                                                \
        }                                                                                                          \
        *result = (nk_##output_type##_t)(total);                                                                   \
    }
|
|
62
|
+
|
|
63
|
+
/**
 *  @brief Macro for complex bilinear form aᵀ × C × b with simple accumulation.
 *
 *  Suitable for upcasted complex types where the wider accumulator provides
 *  sufficient precision headroom. Real and imaginary parts are accumulated
 *  separately in the accumulator type; the final narrowing casts are explicit,
 *  matching the scalar `nk_define_bilinear_` macro.
 *
 *  `load_and_convert` reads one scalar component and widens it into the
 *  accumulator type; inputs are interleaved (real, imag) pairs.
 */
#define nk_define_bilinear_complex_(input_type, accumulator_type, output_type, load_and_convert)                     \
    NK_PUBLIC void nk_bilinear_##input_type##_serial(                                                                \
        nk_##input_type##_t const *a_pairs, nk_##input_type##_t const *b_pairs, nk_##input_type##_t const *c_pairs,  \
        nk_size_t n, nk_##output_type##c_t *results) {                                                               \
        nk_##accumulator_type##_t outer_sum_real = 0, outer_sum_imag = 0;                                            \
        nk_##accumulator_type##_t a_real, a_imag, b_real, b_imag, c_real, c_imag;                                    \
        for (nk_size_t row = 0; row != n; ++row) {                                                                   \
            nk_##accumulator_type##_t inner_sum_real = 0, inner_sum_imag = 0;                                        \
            load_and_convert(&(a_pairs + row)->real, &a_real);                                                       \
            load_and_convert(&(a_pairs + row)->imag, &a_imag);                                                       \
            for (nk_size_t column = 0; column != n; ++column) {                                                      \
                load_and_convert(&(b_pairs + column)->real, &b_real);                                                \
                load_and_convert(&(b_pairs + column)->imag, &b_imag);                                                \
                load_and_convert(&(c_pairs + row * n + column)->real, &c_real);                                      \
                load_and_convert(&(c_pairs + row * n + column)->imag, &c_imag);                                      \
                /* Complex multiply-accumulate: c_ij × b_j */                                                        \
                inner_sum_real += c_real * b_real - c_imag * b_imag;                                                 \
                inner_sum_imag += c_real * b_imag + c_imag * b_real;                                                 \
            }                                                                                                        \
            /* Complex multiply: a_i × inner_result */                                                               \
            outer_sum_real += a_real * inner_sum_real - a_imag * inner_sum_imag;                                     \
            outer_sum_imag += a_real * inner_sum_imag + a_imag * inner_sum_real;                                     \
        }                                                                                                            \
        results->real = (nk_##output_type##_t)(outer_sum_real);                                                      \
        results->imag = (nk_##output_type##_t)(outer_sum_imag);                                                      \
    }
|
|
94
|
+
|
|
95
|
+
/**
 *  @brief Macro for Mahalanobis distance √((a−b)ᵀ × C × (a−b)) with simple accumulation.
 *
 *  Suitable for upcasted types where the wider accumulator provides sufficient
 *  precision headroom. Element differences are formed after widening, i.e. in
 *  accumulator precision, avoiding cancellation at input precision.
 */
#define nk_define_mahalanobis_(input_type, accumulator_type, output_type, load_and_convert)                         \
    NK_PUBLIC void nk_mahalanobis_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
                                                        nk_##input_type##_t const *c, nk_size_t n,                  \
                                                        nk_##output_type##_t *result) {                             \
        nk_##accumulator_type##_t total = 0;                                                                        \
        nk_##accumulator_type##_t a_i, b_i, a_j, b_j, c_ij;                                                         \
        for (nk_size_t i = 0; i < n; ++i) {                                                                         \
            nk_##accumulator_type##_t row_sum = 0;                                                                  \
            load_and_convert(a + i, &a_i);                                                                          \
            load_and_convert(b + i, &b_i);                                                                          \
            nk_##accumulator_type##_t diff_i = a_i - b_i;                                                           \
            for (nk_size_t j = 0; j < n; ++j) {                                                                     \
                load_and_convert(a + j, &a_j);                                                                      \
                load_and_convert(b + j, &b_j);                                                                      \
                load_and_convert(c + i * n + j, &c_ij);                                                             \
                nk_##accumulator_type##_t diff_j = a_j - b_j;                                                       \
                row_sum += c_ij * diff_j;                                                                           \
            }                                                                                                       \
            total += diff_i * row_sum;                                                                              \
        }                                                                                                           \
        /* Clamp tiny negative round-off before the square root. */                                                 \
        *result = nk_##accumulator_type##_sqrt_serial(total > 0 ? total : 0);                                       \
    }
|
|
124
|
+
|
|
125
|
+
// Instantiate the serial bilinear / Mahalanobis kernels for every upcasted type.
// f64 inputs are handled separately below with compensated (Dot2) accumulation.

// f32 → f64 accumulator → f64 output
nk_define_bilinear_(f32, f64, f64, nk_assign_from_to_)          // nk_bilinear_f32_serial
nk_define_bilinear_complex_(f32c, f64, f64, nk_assign_from_to_) // nk_bilinear_f32c_serial
nk_define_mahalanobis_(f32, f64, f64, nk_assign_from_to_)       // nk_mahalanobis_f32_serial

// f16 → f32 accumulator → f32 output: f32 provides ample headroom for f16 (~3 vs ~7 decimal digits)
nk_define_bilinear_(f16, f32, f32, nk_f16_to_f32_serial)          // nk_bilinear_f16_serial
nk_define_bilinear_complex_(f16c, f32, f32, nk_f16_to_f32_serial) // nk_bilinear_f16c_serial
nk_define_mahalanobis_(f16, f32, f32, nk_f16_to_f32_serial)       // nk_mahalanobis_f16_serial

// bf16 → f32 accumulator → f32 output: f32 provides ample headroom for bf16 (~2 vs ~7 decimal digits)
nk_define_bilinear_(bf16, f32, f32, nk_bf16_to_f32_serial)          // nk_bilinear_bf16_serial
nk_define_bilinear_complex_(bf16c, f32, f32, nk_bf16_to_f32_serial) // nk_bilinear_bf16c_serial
nk_define_mahalanobis_(bf16, f32, f32, nk_bf16_to_f32_serial)       // nk_mahalanobis_bf16_serial

// The generator macros are internal to this header; drop them after instantiation.
#undef nk_define_bilinear_
#undef nk_define_bilinear_complex_
#undef nk_define_mahalanobis_
|
|
143
|
+
|
|
144
|
+
/**
 *  @brief Bilinear form aᵀ × C × b for f64 inputs with Dot2 compensated accumulation.
 *
 *  With no wider accumulator available, both accumulation levels use the Dot2
 *  error-free transformation (Ogita-Rump-Oishi 2005), carrying a separate
 *  compensation term alongside each running sum.
 */
NK_PUBLIC void nk_bilinear_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                                      nk_f64_t *result) {
    nk_f64_t form_sum = 0, form_err = 0;
    for (nk_size_t i = 0; i < n; ++i) {
        // Compensated inner product of the i-th row of C with b.
        nk_f64_t row_sum = 0, row_err = 0;
        nk_f64_t const *c_row = c + i * n;
        for (nk_size_t j = 0; j < n; ++j) nk_f64_dot2_(&row_sum, &row_err, c_row[j], b[j]);
        nk_f64_dot2_(&form_sum, &form_err, a[i], row_sum + row_err);
    }
    *result = form_sum + form_err;
}
|
|
155
|
+
|
|
156
|
+
NK_PUBLIC void nk_bilinear_f64c_serial(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_f64c_t const *c_pairs,
|
|
157
|
+
nk_size_t n, nk_f64c_t *results) {
|
|
158
|
+
nk_f64_t outer_sum_real = 0, outer_comp_real = 0;
|
|
159
|
+
nk_f64_t outer_sum_imag = 0, outer_comp_imag = 0;
|
|
160
|
+
for (nk_size_t row = 0; row != n; ++row) {
|
|
161
|
+
nk_f64_t a_real = a_pairs[row].real;
|
|
162
|
+
nk_f64_t a_imag = a_pairs[row].imag;
|
|
163
|
+
// 4 Dot2 accumulators for inner cross-terms
|
|
164
|
+
nk_f64_t sum_rr = 0, comp_rr = 0;
|
|
165
|
+
nk_f64_t sum_ii = 0, comp_ii = 0;
|
|
166
|
+
nk_f64_t sum_ri = 0, comp_ri = 0;
|
|
167
|
+
nk_f64_t sum_ir = 0, comp_ir = 0;
|
|
168
|
+
for (nk_size_t col = 0; col != n; ++col) {
|
|
169
|
+
nk_f64_t b_real = b_pairs[col].real, b_imag = b_pairs[col].imag;
|
|
170
|
+
nk_f64_t c_real = c_pairs[row * n + col].real, c_imag = c_pairs[row * n + col].imag;
|
|
171
|
+
nk_f64_dot2_(&sum_rr, &comp_rr, c_real, b_real);
|
|
172
|
+
nk_f64_dot2_(&sum_ii, &comp_ii, c_imag, b_imag);
|
|
173
|
+
nk_f64_dot2_(&sum_ri, &comp_ri, c_real, b_imag);
|
|
174
|
+
nk_f64_dot2_(&sum_ir, &comp_ir, c_imag, b_real);
|
|
175
|
+
}
|
|
176
|
+
nk_f64_t inner_real = (sum_rr + comp_rr) - (sum_ii + comp_ii);
|
|
177
|
+
nk_f64_t inner_imag = (sum_ri + comp_ri) + (sum_ir + comp_ir);
|
|
178
|
+
// Outer Dot2 complex multiply: a × inner
|
|
179
|
+
nk_f64_dot2_(&outer_sum_real, &outer_comp_real, a_real, inner_real);
|
|
180
|
+
nk_f64_dot2_(&outer_sum_real, &outer_comp_real, -a_imag, inner_imag);
|
|
181
|
+
nk_f64_dot2_(&outer_sum_imag, &outer_comp_imag, a_real, inner_imag);
|
|
182
|
+
nk_f64_dot2_(&outer_sum_imag, &outer_comp_imag, a_imag, inner_real);
|
|
183
|
+
}
|
|
184
|
+
results->real = outer_sum_real + outer_comp_real;
|
|
185
|
+
results->imag = outer_sum_imag + outer_comp_imag;
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
/**
 *  @brief Mahalanobis distance √((a−b)ᵀ × C × (a−b)) for f64 inputs with Dot2 accumulation.
 *
 *  Differences are formed directly in f64; both accumulation levels use the
 *  Dot2 compensated transformation. The quadratic form is clamped at zero
 *  before the square root to absorb negative round-off.
 */
NK_PUBLIC void nk_mahalanobis_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                                         nk_f64_t *result) {
    nk_f64_t form_sum = 0, form_err = 0;
    for (nk_size_t i = 0; i < n; ++i) {
        nk_f64_t delta_i = a[i] - b[i];
        // Compensated inner product of the i-th row of C with (a − b).
        nk_f64_t row_sum = 0, row_err = 0;
        nk_f64_t const *c_row = c + i * n;
        for (nk_size_t j = 0; j < n; ++j) nk_f64_dot2_(&row_sum, &row_err, c_row[j], a[j] - b[j]);
        nk_f64_dot2_(&form_sum, &form_err, delta_i, row_sum + row_err);
    }
    nk_f64_t quadratic = form_sum + form_err;
    *result = nk_f64_sqrt_serial(quadratic > 0 ? quadratic : 0);
}
|
|
202
|
+
|
|
203
|
+
#if defined(__cplusplus)
|
|
204
|
+
} // extern "C"
|
|
205
|
+
#endif
|
|
206
|
+
|
|
207
|
+
#endif // NK_CURVED_SERIAL_H
|