numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Curved Space Similarity for NEON BF16.
|
|
3
|
+
* @file include/numkong/curved/neonbfdot.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 14, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/curved.h
|
|
8
|
+
*
|
|
9
|
+
* Implements bf16 and complex bf16c bilinear forms, and bf16 Mahalanobis distance, using ARM NEON with BF16 extensions.
|
|
10
|
+
*
|
|
11
|
+
* @section curved_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
|
|
12
|
+
*
|
|
13
|
+
* Intrinsic Instruction Latency Throughput
|
|
14
|
+
* A76 M4+/V1+/Oryon
|
|
15
|
+
* vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy 2/cy 4/cy
|
|
16
|
+
* vcvt_f32_bf16 SHLL (V.4S, V.4H, #16) 3cy 2/cy 4/cy
|
|
17
|
+
* vld1q_bf16 LD1 (V.8H) 4cy 2/cy 3/cy
|
|
18
|
+
* vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
|
|
19
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
|
|
20
|
+
*
|
|
21
|
+
* For bilinear forms, BFDOT enables efficient inner-product computation by processing 8 bf16
|
|
22
|
+
* pairs into 4 f32 results per instruction. For Mahalanobis distance, bf16 inputs are converted
|
|
23
|
+
* to f32 for subtraction, then accumulated using FMA for numerical stability.
|
|
24
|
+
*/
|
|
25
|
+
#ifndef NK_CURVED_NEONBFDOT_H
|
|
26
|
+
#define NK_CURVED_NEONBFDOT_H
|
|
27
|
+
|
|
28
|
+
#if NK_TARGET_ARM_
|
|
29
|
+
#if NK_TARGET_NEONBFDOT
|
|
30
|
+
|
|
31
|
+
#include "numkong/types.h" // `nk_bf16_t`
|
|
32
|
+
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
33
|
+
#include "numkong/cast/serial.h" // `nk_bf16_to_f32_serial`
|
|
34
|
+
|
|
35
|
+
#if defined(__cplusplus)
|
|
36
|
+
extern "C" {
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
#if defined(__clang__)
|
|
40
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function)
|
|
41
|
+
#elif defined(__GNUC__)
|
|
42
|
+
#pragma GCC push_options
|
|
43
|
+
#pragma GCC target("arch=armv8.6-a+simd+bf16")
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
NK_PUBLIC void nk_bilinear_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
47
|
+
nk_f32_t *result) {
|
|
48
|
+
float32x4_t outer_sum_f32x4 = vdupq_n_f32(0);
|
|
49
|
+
|
|
50
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
51
|
+
// Load a[i] and broadcast to f32
|
|
52
|
+
nk_f32_t a_i_f32;
|
|
53
|
+
nk_bf16_to_f32_serial(a + i, &a_i_f32);
|
|
54
|
+
float32x4_t a_i_f32x4 = vdupq_n_f32(a_i_f32);
|
|
55
|
+
|
|
56
|
+
// Inner sum
|
|
57
|
+
float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
|
|
58
|
+
nk_size_t j = 0;
|
|
59
|
+
|
|
60
|
+
// Process 8 elements at a time using BFDOT
|
|
61
|
+
for (; j + 8 <= n; j += 8) {
|
|
62
|
+
bfloat16x8_t b_bf16x8 = vld1q_bf16((nk_bf16_for_arm_simd_t const *)(b + j));
|
|
63
|
+
bfloat16x8_t c_bf16x8 = vld1q_bf16((nk_bf16_for_arm_simd_t const *)(c + i * n + j));
|
|
64
|
+
inner_sum_f32x4 = vbfdotq_f32(inner_sum_f32x4, c_bf16x8, b_bf16x8);
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
// Handle tail elements (less than 8)
|
|
68
|
+
if (j < n) {
|
|
69
|
+
nk_b128_vec_t b_vec, c_vec;
|
|
70
|
+
nk_partial_load_b16x8_serial_(b + j, &b_vec, n - j);
|
|
71
|
+
nk_partial_load_b16x8_serial_(c + i * n + j, &c_vec, n - j);
|
|
72
|
+
bfloat16x8_t b_bf16x8 = vreinterpretq_bf16_u16(b_vec.u16x8);
|
|
73
|
+
bfloat16x8_t c_bf16x8 = vreinterpretq_bf16_u16(c_vec.u16x8);
|
|
74
|
+
inner_sum_f32x4 = vbfdotq_f32(inner_sum_f32x4, c_bf16x8, b_bf16x8);
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
// Accumulate: outer_sum += a[i] * inner_sum
|
|
78
|
+
outer_sum_f32x4 = vfmaq_f32(outer_sum_f32x4, a_i_f32x4, inner_sum_f32x4);
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
*result = vaddvq_f32(outer_sum_f32x4);
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
NK_PUBLIC void nk_mahalanobis_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
85
|
+
nk_f32_t *result) {
|
|
86
|
+
nk_f32_t outer_sum = 0;
|
|
87
|
+
|
|
88
|
+
for (nk_size_t i = 0; i != n; ++i) {
|
|
89
|
+
// Compute diff_i = a[i] - b[i] in f32
|
|
90
|
+
nk_f32_t a_i_f32, b_i_f32;
|
|
91
|
+
nk_bf16_to_f32_serial(a + i, &a_i_f32);
|
|
92
|
+
nk_bf16_to_f32_serial(b + i, &b_i_f32);
|
|
93
|
+
nk_f32_t diff_i = a_i_f32 - b_i_f32;
|
|
94
|
+
|
|
95
|
+
// Inner sum
|
|
96
|
+
float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
|
|
97
|
+
nk_size_t j = 0;
|
|
98
|
+
|
|
99
|
+
// Process 4 elements at a time (convert bf16->f32, subtract, then FMA)
|
|
100
|
+
for (; j + 4 <= n; j += 4) {
|
|
101
|
+
bfloat16x4_t a_j_bf16x4 = vld1_bf16((nk_bf16_for_arm_simd_t const *)(a + j));
|
|
102
|
+
bfloat16x4_t b_j_bf16x4 = vld1_bf16((nk_bf16_for_arm_simd_t const *)(b + j));
|
|
103
|
+
bfloat16x4_t c_bf16x4 = vld1_bf16((nk_bf16_for_arm_simd_t const *)(c + i * n + j));
|
|
104
|
+
|
|
105
|
+
float32x4_t a_j_f32x4 = vcvt_f32_bf16(a_j_bf16x4);
|
|
106
|
+
float32x4_t b_j_f32x4 = vcvt_f32_bf16(b_j_bf16x4);
|
|
107
|
+
float32x4_t c_f32x4 = vcvt_f32_bf16(c_bf16x4);
|
|
108
|
+
|
|
109
|
+
float32x4_t diff_j_f32x4 = vsubq_f32(a_j_f32x4, b_j_f32x4);
|
|
110
|
+
inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_f32x4, diff_j_f32x4);
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
// Handle tail elements
|
|
114
|
+
nk_f32_t inner_sum_tail = 0;
|
|
115
|
+
for (; j < n; ++j) {
|
|
116
|
+
nk_f32_t a_j_f32, b_j_f32, c_f32;
|
|
117
|
+
nk_bf16_to_f32_serial(a + j, &a_j_f32);
|
|
118
|
+
nk_bf16_to_f32_serial(b + j, &b_j_f32);
|
|
119
|
+
nk_bf16_to_f32_serial(c + i * n + j, &c_f32);
|
|
120
|
+
inner_sum_tail += c_f32 * (a_j_f32 - b_j_f32);
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
// Reduce inner sum and add tail
|
|
124
|
+
nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4) + inner_sum_tail;
|
|
125
|
+
|
|
126
|
+
// Accumulate: outer_sum += diff_i * inner_sum
|
|
127
|
+
outer_sum += diff_i * inner_sum;
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
nk_f32_t quadratic = outer_sum;
|
|
131
|
+
*result = nk_f32_sqrt_neon(quadratic > 0 ? quadratic : 0);
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
/**
 * @brief Complex bilinear form ∑ᵢ aᵢ × (∑ⱼ cᵢⱼ × bⱼ) for bf16c inputs, accumulated in f32.
 *
 * All products are plain complex products; no operand is conjugated. `c_pairs` is addressed
 * row-major as `c_pairs[row * n + column]`.
 *
 * @param a_pairs Left vector, `n` complex bf16 values.
 * @param b_pairs Right vector, `n` complex bf16 values.
 * @param c_pairs Square matrix, `n * n` complex bf16 values.
 * @param n       Dimensionality.
 * @param result  Output: the complex f32 result.
 */
NK_PUBLIC void nk_bilinear_bf16c_neonbfdot(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs,
                                           nk_bf16c_t const *c_pairs, nk_size_t n, nk_f32c_t *result) {
    // ARMv8.3-A FCMLA was benchmarked for this complex multiply pattern.
    // The deinterleave+4FMA approach is 2.3x faster on Apple M4 — see `dot/neon.h` comment.
    nk_f32_t total_real = 0, total_imag = 0;

    for (nk_size_t row = 0; row != n; ++row) {
        nk_bf16c_t const *c_row = c_pairs + row * n;

        nk_f32_t a_real, a_imag;
        nk_bf16_to_f32_serial(&a_pairs[row].real, &a_real);
        nk_bf16_to_f32_serial(&a_pairs[row].imag, &a_imag);

        float32x4_t row_real_f32x4 = vdupq_n_f32(0);
        float32x4_t row_imag_f32x4 = vdupq_n_f32(0);
        nk_size_t column = 0;
        while (column + 4 <= n) {
            // vld2 splits the interleaved pairs: val[0] holds real parts, val[1] imaginary parts.
            // MSVC doesn't support vld2_bf16, so load as s16 and reinterpret.
            int16x4x2_t b_i16x4x2 = vld2_s16((short const *)(b_pairs + column));
            int16x4x2_t c_i16x4x2 = vld2_s16((short const *)(c_row + column));
            float32x4_t b_real_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(b_i16x4x2.val[0]));
            float32x4_t b_imag_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(b_i16x4x2.val[1]));
            float32x4_t c_real_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(c_i16x4x2.val[0]));
            float32x4_t c_imag_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(c_i16x4x2.val[1]));

            // (c × b).real = c.real·b.real − c.imag·b.imag
            row_real_f32x4 = vfmaq_f32(row_real_f32x4, c_real_f32x4, b_real_f32x4);
            row_real_f32x4 = vfmsq_f32(row_real_f32x4, c_imag_f32x4, b_imag_f32x4);
            // (c × b).imag = c.real·b.imag + c.imag·b.real
            row_imag_f32x4 = vfmaq_f32(row_imag_f32x4, c_real_f32x4, b_imag_f32x4);
            row_imag_f32x4 = vfmaq_f32(row_imag_f32x4, c_imag_f32x4, b_real_f32x4);
            column += 4;
        }

        // Scalar tail: summed on its own, then added after the horizontal reductions.
        nk_f32_t tail_real = 0, tail_imag = 0;
        for (; column < n; ++column) {
            nk_f32_t b_real, b_imag, c_real, c_imag;
            nk_bf16_to_f32_serial(&b_pairs[column].real, &b_real);
            nk_bf16_to_f32_serial(&b_pairs[column].imag, &b_imag);
            nk_bf16_to_f32_serial(&c_row[column].real, &c_real);
            nk_bf16_to_f32_serial(&c_row[column].imag, &c_imag);
            tail_real += c_real * b_real - c_imag * b_imag;
            tail_imag += c_real * b_imag + c_imag * b_real;
        }

        nk_f32_t const row_real = vaddvq_f32(row_real_f32x4) + tail_real;
        nk_f32_t const row_imag = vaddvq_f32(row_imag_f32x4) + tail_imag;

        // total += aᵢ × row_sum, as a complex product.
        total_real += a_real * row_real - a_imag * row_imag;
        total_imag += a_real * row_imag + a_imag * row_real;
    }

    result->real = total_real;
    result->imag = total_imag;
}
|
|
199
|
+
|
|
200
|
+
#if defined(__clang__)
|
|
201
|
+
#pragma clang attribute pop
|
|
202
|
+
#elif defined(__GNUC__)
|
|
203
|
+
#pragma GCC pop_options
|
|
204
|
+
#endif
|
|
205
|
+
|
|
206
|
+
#if defined(__cplusplus)
|
|
207
|
+
} // extern "C"
|
|
208
|
+
#endif
|
|
209
|
+
|
|
210
|
+
#endif // NK_TARGET_NEONBFDOT
|
|
211
|
+
#endif // NK_TARGET_ARM_
|
|
212
|
+
#endif // NK_CURVED_NEONBFDOT_H
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Curved Space Similarity for NEON FP16.
|
|
3
|
+
* @file include/numkong/curved/neonhalf.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 14, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/curved.h
|
|
8
|
+
*
|
|
9
|
+
* Implements f16 and complex f16c bilinear forms, and f16 Mahalanobis distance, using ARM NEON with FP16 extensions.
|
|
10
|
+
*
|
|
11
|
+
* @section curved_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
|
|
12
|
+
*
|
|
13
|
+
* Intrinsic Instruction Latency Throughput
|
|
14
|
+
* A76 M4+/V1+/Oryon
|
|
15
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
|
|
16
|
+
* vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
|
|
17
|
+
* vld1_f16 LD1 (V.4H) 4cy 2/cy 3/cy
|
|
18
|
+
* vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
|
|
19
|
+
* vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
|
|
20
|
+
*
|
|
21
|
+
* Bilinear forms involve nested summation O(n^2) operations. For numerical stability,
|
|
22
|
+
* f16 inputs are widened to f32 for accumulation. The matrix C is accessed row-by-row
|
|
23
|
+
* to maintain cache locality.
|
|
24
|
+
*
|
|
25
|
+
* Mathematical definitions:
|
|
26
|
+
* - Bilinear: result = ∑ᵢ ∑ⱼ aᵢ × cᵢⱼ × bⱼ
|
|
27
|
+
* - Mahalanobis: result = √((a - b)ᵀ × C × (a - b))
|
|
28
|
+
*/
|
|
29
|
+
#ifndef NK_CURVED_NEONHALF_H
|
|
30
|
+
#define NK_CURVED_NEONHALF_H
|
|
31
|
+
|
|
32
|
+
#if NK_TARGET_ARM_
|
|
33
|
+
#if NK_TARGET_NEONHALF
|
|
34
|
+
|
|
35
|
+
#include "numkong/types.h"
|
|
36
|
+
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
37
|
+
#include "numkong/cast/serial.h" // `nk_f16_to_f32_serial`
|
|
38
|
+
|
|
39
|
+
#if defined(__cplusplus)
|
|
40
|
+
extern "C" {
|
|
41
|
+
#endif
|
|
42
|
+
|
|
43
|
+
#if defined(__clang__)
|
|
44
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
|
|
45
|
+
#elif defined(__GNUC__)
|
|
46
|
+
#pragma GCC push_options
|
|
47
|
+
#pragma GCC target("arch=armv8.2-a+simd+fp16")
|
|
48
|
+
#endif
|
|
49
|
+
|
|
50
|
+
NK_PUBLIC void nk_bilinear_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
51
|
+
nk_f32_t *result) {
|
|
52
|
+
nk_f32_t outer_sum = 0;
|
|
53
|
+
|
|
54
|
+
// Process rows of the matrix
|
|
55
|
+
for (nk_size_t row = 0; row != n; ++row) {
|
|
56
|
+
nk_f16_t const *c_row = c + row * n;
|
|
57
|
+
|
|
58
|
+
// Load a[row] as f32
|
|
59
|
+
nk_f32_t a_row;
|
|
60
|
+
nk_f16_to_f32_serial(a + row, &a_row);
|
|
61
|
+
|
|
62
|
+
// Compute inner sum
|
|
63
|
+
float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
|
|
64
|
+
nk_size_t column = 0;
|
|
65
|
+
|
|
66
|
+
// Process 4 elements at a time
|
|
67
|
+
for (; column + 4 <= n; column += 4) {
|
|
68
|
+
float32x4_t b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(b + column)));
|
|
69
|
+
float32x4_t c_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(c_row + column)));
|
|
70
|
+
inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_f32x4, b_f32x4);
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
// Reduce SIMD accumulator
|
|
74
|
+
nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4);
|
|
75
|
+
|
|
76
|
+
// Handle tail elements with scalar code
|
|
77
|
+
for (; column < n; ++column) {
|
|
78
|
+
nk_f32_t b_val, c_val;
|
|
79
|
+
nk_f16_to_f32_serial(b + column, &b_val);
|
|
80
|
+
nk_f16_to_f32_serial(c_row + column, &c_val);
|
|
81
|
+
inner_sum += c_val * b_val;
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
// Multiply by a[row] and accumulate
|
|
85
|
+
outer_sum += a_row * inner_sum;
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
*result = outer_sum;
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
NK_PUBLIC void nk_mahalanobis_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
92
|
+
nk_f32_t *result) {
|
|
93
|
+
nk_f32_t outer_sum = 0;
|
|
94
|
+
|
|
95
|
+
// Process rows of the matrix
|
|
96
|
+
for (nk_size_t row = 0; row != n; ++row) {
|
|
97
|
+
nk_f16_t const *c_row = c + row * n;
|
|
98
|
+
|
|
99
|
+
// Compute diff_row = a[row] - b[row] in f32
|
|
100
|
+
nk_f32_t a_row, b_row;
|
|
101
|
+
nk_f16_to_f32_serial(a + row, &a_row);
|
|
102
|
+
nk_f16_to_f32_serial(b + row, &b_row);
|
|
103
|
+
nk_f32_t diff_row = a_row - b_row;
|
|
104
|
+
|
|
105
|
+
// Compute inner sum
|
|
106
|
+
float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
|
|
107
|
+
nk_size_t column = 0;
|
|
108
|
+
|
|
109
|
+
// Process 4 elements at a time
|
|
110
|
+
for (; column + 4 <= n; column += 4) {
|
|
111
|
+
float32x4_t a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(a + column)));
|
|
112
|
+
float32x4_t b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(b + column)));
|
|
113
|
+
float32x4_t c_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(c_row + column)));
|
|
114
|
+
float32x4_t diff_column_f32x4 = vsubq_f32(a_f32x4, b_f32x4);
|
|
115
|
+
inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_f32x4, diff_column_f32x4);
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
// Reduce SIMD accumulator
|
|
119
|
+
nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4);
|
|
120
|
+
|
|
121
|
+
// Handle tail elements with scalar code
|
|
122
|
+
for (; column < n; ++column) {
|
|
123
|
+
nk_f32_t a_val, b_val, c_val;
|
|
124
|
+
nk_f16_to_f32_serial(a + column, &a_val);
|
|
125
|
+
nk_f16_to_f32_serial(b + column, &b_val);
|
|
126
|
+
nk_f16_to_f32_serial(c_row + column, &c_val);
|
|
127
|
+
inner_sum += c_val * (a_val - b_val);
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
// Multiply by diff_row and accumulate
|
|
131
|
+
outer_sum += diff_row * inner_sum;
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
nk_f32_t quadratic = outer_sum;
|
|
135
|
+
*result = nk_f32_sqrt_neon(quadratic > 0 ? quadratic : 0);
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
NK_PUBLIC void nk_bilinear_f16c_neonhalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_f16c_t const *c_pairs,
|
|
139
|
+
nk_size_t n, nk_f32c_t *results) {
|
|
140
|
+
nk_f32_t outer_sum_real = 0;
|
|
141
|
+
nk_f32_t outer_sum_imag = 0;
|
|
142
|
+
|
|
143
|
+
// Process rows of the matrix
|
|
144
|
+
for (nk_size_t row = 0; row != n; ++row) {
|
|
145
|
+
nk_f16c_t const *c_row = c_pairs + row * n;
|
|
146
|
+
|
|
147
|
+
// Load a[row] complex value
|
|
148
|
+
nk_f32_t a_real, a_imag;
|
|
149
|
+
nk_f16_to_f32_serial(&(a_pairs + row)->real, &a_real);
|
|
150
|
+
nk_f16_to_f32_serial(&(a_pairs + row)->imag, &a_imag);
|
|
151
|
+
|
|
152
|
+
// Compute inner sum
|
|
153
|
+
float32x4_t inner_sum_real_f32x4 = vdupq_n_f32(0);
|
|
154
|
+
float32x4_t inner_sum_imag_f32x4 = vdupq_n_f32(0);
|
|
155
|
+
nk_size_t column = 0;
|
|
156
|
+
|
|
157
|
+
// Process 4 complex pairs at a time using deinterleaved loads
|
|
158
|
+
for (; column + 4 <= n; column += 4) {
|
|
159
|
+
// Deinterleave real/imaginary using vld2_s16 pattern from dot/neonhalf.h
|
|
160
|
+
int16x4x2_t b_i16x4x2 = vld2_s16((short const *)(b_pairs + column));
|
|
161
|
+
int16x4x2_t c_i16x4x2 = vld2_s16((short const *)(c_row + column));
|
|
162
|
+
float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
|
|
163
|
+
float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
|
|
164
|
+
float32x4_t c_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(c_i16x4x2.val[0]));
|
|
165
|
+
float32x4_t c_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(c_i16x4x2.val[1]));
|
|
166
|
+
|
|
167
|
+
// Complex multiply
|
|
168
|
+
inner_sum_real_f32x4 = vfmaq_f32(inner_sum_real_f32x4, c_real_f32x4, b_real_f32x4);
|
|
169
|
+
inner_sum_real_f32x4 = vfmsq_f32(inner_sum_real_f32x4, c_imag_f32x4, b_imag_f32x4);
|
|
170
|
+
inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_real_f32x4, b_imag_f32x4);
|
|
171
|
+
inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_imag_f32x4, b_real_f32x4);
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
// Reduce SIMD accumulators
|
|
175
|
+
nk_f32_t inner_sum_real = vaddvq_f32(inner_sum_real_f32x4);
|
|
176
|
+
nk_f32_t inner_sum_imag = vaddvq_f32(inner_sum_imag_f32x4);
|
|
177
|
+
|
|
178
|
+
// Handle tail elements with scalar code
|
|
179
|
+
for (; column < n; ++column) {
|
|
180
|
+
nk_f32_t b_real, b_imag, c_real, c_imag;
|
|
181
|
+
nk_f16_to_f32_serial(&(b_pairs + column)->real, &b_real);
|
|
182
|
+
nk_f16_to_f32_serial(&(b_pairs + column)->imag, &b_imag);
|
|
183
|
+
nk_f16_to_f32_serial(&(c_row + column)->real, &c_real);
|
|
184
|
+
nk_f16_to_f32_serial(&(c_row + column)->imag, &c_imag);
|
|
185
|
+
|
|
186
|
+
// Complex multiply
|
|
187
|
+
inner_sum_real += c_real * b_real - c_imag * b_imag;
|
|
188
|
+
inner_sum_imag += c_real * b_imag + c_imag * b_real;
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
// Complex multiply
|
|
192
|
+
outer_sum_real += a_real * inner_sum_real - a_imag * inner_sum_imag;
|
|
193
|
+
outer_sum_imag += a_real * inner_sum_imag + a_imag * inner_sum_real;
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
results->real = outer_sum_real;
|
|
197
|
+
results->imag = outer_sum_imag;
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
#if defined(__clang__)
|
|
201
|
+
#pragma clang attribute pop
|
|
202
|
+
#elif defined(__GNUC__)
|
|
203
|
+
#pragma GCC pop_options
|
|
204
|
+
#endif
|
|
205
|
+
|
|
206
|
+
#if defined(__cplusplus)
|
|
207
|
+
} // extern "C"
|
|
208
|
+
#endif
|
|
209
|
+
|
|
210
|
+
#endif // NK_TARGET_NEONHALF
|
|
211
|
+
#endif // NK_TARGET_ARM_
|
|
212
|
+
#endif // NK_CURVED_NEONHALF_H
|