numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,818 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for NEON.
|
|
3
|
+
* @file include/numkong/dot/neon.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_neon_instructions NEON Dot Product Instructions
|
|
10
|
+
*
|
|
11
|
+
* Key NEON instructions for dot products:
|
|
12
|
+
*
|
|
13
|
+
* Intrinsic Instruction Latency Throughput
|
|
14
|
+
* A76 M4+/V1+/Oryon
|
|
15
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
|
|
16
|
+
* vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy 2/cy 4/cy
|
|
17
|
+
* vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy 2/cy 4/cy
|
|
18
|
+
* vaddvq_f32 FADDP+FADDP (reduce) 5cy 1/cy 1/cy
|
|
19
|
+
* vaddvq_f64 FADDP (V.2D to scalar) 3cy 1/cy 1/cy
|
|
20
|
+
* vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy 2/cy 2/cy
|
|
21
|
+
* vld2_f32 LD2 ({Vt.2S, Vt2.2S}, [Xn]) 4cy 1/cy 1/cy
|
|
22
|
+
*
|
|
23
|
+
* FMA throughput doubles on cores with 4 SIMD pipes (Apple M4+, Graviton3+, Oryon), but
|
|
24
|
+
* horizontal reductions remain at 1/cy on all cores and become the main bottleneck.
|
|
25
|
+
*
|
|
26
|
+
* For f32 dot products, we upcast to f64 for accumulation to preserve precision and
|
|
27
|
+
* reduce rounding error in large-magnitude sums (the f32 × f32 products themselves are exact in f64).
|
|
28
|
+
*
|
|
29
|
+
* @section dot_neon_stateful Stateful Streaming Logic
|
|
30
|
+
*
|
|
31
|
+
* To build memory-optimal tiled algorithms, this file defines the following structures and force-inlined
|
|
32
|
+
* `NK_INTERNAL` functions:
|
|
33
|
+
*
|
|
34
|
+
* - nk_dot_f32x2 state for f32 inputs with double-precision accumulation,
|
|
35
|
+
* - nk_dot_f64x2 state with Dot2 stable dot-products for f64 inputs.
|
|
36
|
+
*
|
|
37
|
+
* @code{c}
|
|
38
|
+
* nk_dot_f32x2_state_neon_t state_first, state_second, state_third, state_fourth;
|
|
39
|
+
* float32x2_t query_f32x2, target_first_f32x2, target_second_f32x2, target_third_f32x2, target_fourth_f32x2;
|
|
40
|
+
* nk_dot_f32x2_init_neon(&state_first);
|
|
41
|
+
* nk_dot_f32x2_init_neon(&state_second);
|
|
42
|
+
* nk_dot_f32x2_init_neon(&state_third);
|
|
43
|
+
* nk_dot_f32x2_init_neon(&state_fourth);
|
|
44
|
+
* for (nk_size_t idx = 0; idx + 2 <= depth; idx += 2) {
|
|
45
|
+
* query_f32x2 = vld1_f32(query_ptr + idx);
|
|
46
|
+
* target_first_f32x2 = vld1_f32(target_first_ptr + idx);
|
|
47
|
+
* target_second_f32x2 = vld1_f32(target_second_ptr + idx);
|
|
48
|
+
* target_third_f32x2 = vld1_f32(target_third_ptr + idx);
|
|
49
|
+
* target_fourth_f32x2 = vld1_f32(target_fourth_ptr + idx);
|
|
50
|
+
* nk_dot_f32x2_update_neon(&state_first, query_f32x2, target_first_f32x2, idx, 2);
|
|
51
|
+
* nk_dot_f32x2_update_neon(&state_second, query_f32x2, target_second_f32x2, idx, 2);
|
|
52
|
+
* nk_dot_f32x2_update_neon(&state_third, query_f32x2, target_third_f32x2, idx, 2);
|
|
53
|
+
* nk_dot_f32x2_update_neon(&state_fourth, query_f32x2, target_fourth_f32x2, idx, 2);
|
|
54
|
+
* }
|
|
55
|
+
* nk_b256_vec_t results; // `results.f64s[0..3]` receive the four dot products
|
|
56
|
+
* nk_dot_f32x2_finalize_neon(&state_first, &state_second, &state_third, &state_fourth, depth, &results);
|
|
57
|
+
* @endcode
|
|
58
|
+
*
|
|
59
|
+
* For f64 inputs, Dot2 compensated summation provides numerical stability:
|
|
60
|
+
*
|
|
61
|
+
* @code{c}
|
|
62
|
+
* nk_dot_f64x2_state_neon_t state_first, state_second, state_third, state_fourth;
|
|
63
|
+
* float64x2_t query_f64x2, target_first_f64x2, target_second_f64x2, target_third_f64x2, target_fourth_f64x2;
|
|
64
|
+
* nk_dot_f64x2_init_neon(&state_first);
|
|
65
|
+
* nk_dot_f64x2_init_neon(&state_second);
|
|
66
|
+
* nk_dot_f64x2_init_neon(&state_third);
|
|
67
|
+
* nk_dot_f64x2_init_neon(&state_fourth);
|
|
68
|
+
* for (nk_size_t idx = 0; idx + 2 <= depth; idx += 2) {
|
|
69
|
+
* query_f64x2 = vld1q_f64(query_ptr + idx);
|
|
70
|
+
* target_first_f64x2 = vld1q_f64(target_first_ptr + idx);
|
|
71
|
+
* target_second_f64x2 = vld1q_f64(target_second_ptr + idx);
|
|
72
|
+
* target_third_f64x2 = vld1q_f64(target_third_ptr + idx);
|
|
73
|
+
* target_fourth_f64x2 = vld1q_f64(target_fourth_ptr + idx);
|
|
74
|
+
* nk_dot_f64x2_update_neon(&state_first, query_f64x2, target_first_f64x2, idx, 2);
|
|
75
|
+
* nk_dot_f64x2_update_neon(&state_second, query_f64x2, target_second_f64x2, idx, 2);
|
|
76
|
+
* nk_dot_f64x2_update_neon(&state_third, query_f64x2, target_third_f64x2, idx, 2);
|
|
77
|
+
* nk_dot_f64x2_update_neon(&state_fourth, query_f64x2, target_fourth_f64x2, idx, 2);
|
|
78
|
+
* }
|
|
79
|
+
* nk_b256_vec_t results; // four f64 dot products
|
|
80
|
+
* nk_dot_f64x2_finalize_neon(&state_first, &state_second, &state_third, &state_fourth, depth, &results);
|
|
81
|
+
* @endcode
|
|
82
|
+
*/
|
|
83
|
+
#ifndef NK_DOT_NEON_H
|
|
84
|
+
#define NK_DOT_NEON_H
|
|
85
|
+
|
|
86
|
+
#if NK_TARGET_ARM_
|
|
87
|
+
#if NK_TARGET_NEON
|
|
88
|
+
|
|
89
|
+
#include "numkong/cast/neon.h" // `nk_e4m3x8_to_f16x8_neon_`
|
|
90
|
+
|
|
91
|
+
#if defined(__cplusplus)
|
|
92
|
+
extern "C" {
|
|
93
|
+
#endif
|
|
94
|
+
|
|
95
|
+
#if defined(__clang__)
|
|
96
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8-a+simd"))), apply_to = function)
|
|
97
|
+
#elif defined(__GNUC__)
|
|
98
|
+
#pragma GCC push_options
|
|
99
|
+
#pragma GCC target("arch=armv8-a+simd")
|
|
100
|
+
#endif
|
|
101
|
+
|
|
102
|
+
/** @brief Compensated horizontal sum of 2 f64 lanes via TwoSum. */
|
|
103
|
+
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x2_neon_(float64x2_t sum_f64x2, float64x2_t compensation_f64x2) {
|
|
104
|
+
// TwoSum merge of sum + compensation (2-wide)
|
|
105
|
+
float64x2_t tentative_sum_f64x2 = vaddq_f64(sum_f64x2, compensation_f64x2);
|
|
106
|
+
float64x2_t virtual_addend_f64x2 = vsubq_f64(tentative_sum_f64x2, sum_f64x2);
|
|
107
|
+
float64x2_t rounding_error_f64x2 = vaddq_f64(
|
|
108
|
+
vsubq_f64(sum_f64x2, vsubq_f64(tentative_sum_f64x2, virtual_addend_f64x2)),
|
|
109
|
+
vsubq_f64(compensation_f64x2, virtual_addend_f64x2));
|
|
110
|
+
// Scalar TwoSum 2→1
|
|
111
|
+
nk_f64_t lower_sum = vgetq_lane_f64(tentative_sum_f64x2, 0);
|
|
112
|
+
nk_f64_t upper_sum = vgetq_lane_f64(tentative_sum_f64x2, 1);
|
|
113
|
+
nk_f64_t lower_error = vgetq_lane_f64(rounding_error_f64x2, 0);
|
|
114
|
+
nk_f64_t upper_error = vgetq_lane_f64(rounding_error_f64x2, 1);
|
|
115
|
+
nk_f64_t tentative_sum = lower_sum + upper_sum;
|
|
116
|
+
nk_f64_t virtual_addend = tentative_sum - lower_sum;
|
|
117
|
+
nk_f64_t rounding_error = (lower_sum - (tentative_sum - virtual_addend)) + (upper_sum - virtual_addend);
|
|
118
|
+
return tentative_sum + (lower_error + upper_error + rounding_error);
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
#pragma region - Traditional Floats
|
|
122
|
+
|
|
123
|
+
NK_PUBLIC void nk_dot_f32_neon(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
124
|
+
nk_f64_t *result) {
|
|
125
|
+
// Upcast f32 to f64 for accumulation (2 f32s per iteration, avoids slow vget_low/high)
|
|
126
|
+
float64x2_t sum_f64x2 = vdupq_n_f64(0);
|
|
127
|
+
nk_size_t idx_scalars = 0;
|
|
128
|
+
for (; idx_scalars + 2 <= count_scalars; idx_scalars += 2) {
|
|
129
|
+
float32x2_t a_f32x2 = vld1_f32(a_scalars + idx_scalars);
|
|
130
|
+
float32x2_t b_f32x2 = vld1_f32(b_scalars + idx_scalars);
|
|
131
|
+
float64x2_t a_f64x2 = vcvt_f64_f32(a_f32x2);
|
|
132
|
+
float64x2_t b_f64x2 = vcvt_f64_f32(b_f32x2);
|
|
133
|
+
sum_f64x2 = vfmaq_f64(sum_f64x2, a_f64x2, b_f64x2);
|
|
134
|
+
}
|
|
135
|
+
nk_f64_t sum_f64 = vaddvq_f64(sum_f64x2);
|
|
136
|
+
for (; idx_scalars < count_scalars; ++idx_scalars)
|
|
137
|
+
sum_f64 += (nk_f64_t)a_scalars[idx_scalars] * (nk_f64_t)b_scalars[idx_scalars];
|
|
138
|
+
*result = sum_f64;
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
NK_PUBLIC void nk_dot_f32c_neon(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
142
|
+
nk_f64c_t *result) {
|
|
143
|
+
// Upcast f32 to f64 for accumulation (2 complex pairs per iteration, avoids slow vget_low/high)
|
|
144
|
+
float64x2_t sum_real_f64x2 = vdupq_n_f64(0);
|
|
145
|
+
float64x2_t sum_imag_f64x2 = vdupq_n_f64(0);
|
|
146
|
+
nk_size_t idx_pairs = 0;
|
|
147
|
+
// ARMv8.3-A FCMLA (`vcmlaq_rot0/rot90_f32`) was benchmarked as an alternative to the
|
|
148
|
+
// deinterleave+4FMA pattern below. FCMLA processes only 2 complex pairs per iteration
|
|
149
|
+
// (interleaved 128-bit operands, 2x `vcmlaq`), while `vld2_f32` deinterleaves 2 pairs
|
|
150
|
+
// with 4 independent FMA instructions that fully utilize M4's 4 SIMD pipes. Result on
|
|
151
|
+
// Apple M4 at n=4096: manual f32 39.7 GiB/s, FCMLA 17.1 GiB/s (2.3x slower).
|
|
152
|
+
// The f64 upcast here trades throughput for precision — FCMLA offers neither advantage.
|
|
153
|
+
for (; idx_pairs + 2 <= count_pairs; idx_pairs += 2) {
|
|
154
|
+
// Unpack 2 complex pairs into real and imaginary parts:
|
|
155
|
+
float32x2x2_t a_f32x2x2 = vld2_f32((nk_f32_t const *)(a_pairs + idx_pairs));
|
|
156
|
+
float32x2x2_t b_f32x2x2 = vld2_f32((nk_f32_t const *)(b_pairs + idx_pairs));
|
|
157
|
+
// Upcast to f64
|
|
158
|
+
float64x2_t a_real_f64x2 = vcvt_f64_f32(a_f32x2x2.val[0]);
|
|
159
|
+
float64x2_t a_imag_f64x2 = vcvt_f64_f32(a_f32x2x2.val[1]);
|
|
160
|
+
float64x2_t b_real_f64x2 = vcvt_f64_f32(b_f32x2x2.val[0]);
|
|
161
|
+
float64x2_t b_imag_f64x2 = vcvt_f64_f32(b_f32x2x2.val[1]);
|
|
162
|
+
// Compute the dot product: real = aᵣ × bᵣ - aᵢ × bᵢ, imag = aᵣ × bᵢ + aᵢ × bᵣ
|
|
163
|
+
sum_real_f64x2 = vfmaq_f64(sum_real_f64x2, a_real_f64x2, b_real_f64x2);
|
|
164
|
+
sum_real_f64x2 = vfmsq_f64(sum_real_f64x2, a_imag_f64x2, b_imag_f64x2);
|
|
165
|
+
sum_imag_f64x2 = vfmaq_f64(sum_imag_f64x2, a_real_f64x2, b_imag_f64x2);
|
|
166
|
+
sum_imag_f64x2 = vfmaq_f64(sum_imag_f64x2, a_imag_f64x2, b_real_f64x2);
|
|
167
|
+
}
|
|
168
|
+
// Reduce horizontal sums:
|
|
169
|
+
nk_f64_t sum_real_f64 = vaddvq_f64(sum_real_f64x2);
|
|
170
|
+
nk_f64_t sum_imag_f64 = vaddvq_f64(sum_imag_f64x2);
|
|
171
|
+
// Handle the tail:
|
|
172
|
+
for (; idx_pairs != count_pairs; ++idx_pairs) {
|
|
173
|
+
nk_f32c_t a_pair = a_pairs[idx_pairs], b_pair = b_pairs[idx_pairs];
|
|
174
|
+
nk_f64_t ar = a_pair.real, ai = a_pair.imag, br = b_pair.real, bi = b_pair.imag;
|
|
175
|
+
sum_real_f64 += ar * br - ai * bi;
|
|
176
|
+
sum_imag_f64 += ar * bi + ai * br;
|
|
177
|
+
}
|
|
178
|
+
result->real = sum_real_f64;
|
|
179
|
+
result->imag = sum_imag_f64;
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
NK_PUBLIC void nk_vdot_f32c_neon(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
183
|
+
nk_f64c_t *result) {
|
|
184
|
+
// Upcast f32 to f64 for accumulation (2 complex pairs per iteration, avoids slow vget_low/high)
|
|
185
|
+
float64x2_t sum_real_f64x2 = vdupq_n_f64(0);
|
|
186
|
+
float64x2_t sum_imag_f64x2 = vdupq_n_f64(0);
|
|
187
|
+
nk_size_t idx_pairs = 0;
|
|
188
|
+
for (; idx_pairs + 2 <= count_pairs; idx_pairs += 2) {
|
|
189
|
+
// Unpack 2 complex pairs into real and imaginary parts:
|
|
190
|
+
float32x2x2_t a_f32x2x2 = vld2_f32((nk_f32_t const *)(a_pairs + idx_pairs));
|
|
191
|
+
float32x2x2_t b_f32x2x2 = vld2_f32((nk_f32_t const *)(b_pairs + idx_pairs));
|
|
192
|
+
// Upcast to f64
|
|
193
|
+
float64x2_t a_real_f64x2 = vcvt_f64_f32(a_f32x2x2.val[0]);
|
|
194
|
+
float64x2_t a_imag_f64x2 = vcvt_f64_f32(a_f32x2x2.val[1]);
|
|
195
|
+
float64x2_t b_real_f64x2 = vcvt_f64_f32(b_f32x2x2.val[0]);
|
|
196
|
+
float64x2_t b_imag_f64x2 = vcvt_f64_f32(b_f32x2x2.val[1]);
|
|
197
|
+
// Compute conjugate dot product: real = aᵣ × bᵣ + aᵢ × bᵢ, imag = aᵣ × bᵢ - aᵢ × bᵣ
|
|
198
|
+
sum_real_f64x2 = vfmaq_f64(sum_real_f64x2, a_real_f64x2, b_real_f64x2);
|
|
199
|
+
sum_real_f64x2 = vfmaq_f64(sum_real_f64x2, a_imag_f64x2, b_imag_f64x2);
|
|
200
|
+
sum_imag_f64x2 = vfmaq_f64(sum_imag_f64x2, a_real_f64x2, b_imag_f64x2);
|
|
201
|
+
sum_imag_f64x2 = vfmsq_f64(sum_imag_f64x2, a_imag_f64x2, b_real_f64x2);
|
|
202
|
+
}
|
|
203
|
+
// Reduce horizontal sums:
|
|
204
|
+
nk_f64_t sum_real_f64 = vaddvq_f64(sum_real_f64x2);
|
|
205
|
+
nk_f64_t sum_imag_f64 = vaddvq_f64(sum_imag_f64x2);
|
|
206
|
+
// Handle the tail:
|
|
207
|
+
for (; idx_pairs != count_pairs; ++idx_pairs) {
|
|
208
|
+
nk_f32c_t a_pair = a_pairs[idx_pairs], b_pair = b_pairs[idx_pairs];
|
|
209
|
+
nk_f64_t ar = a_pair.real, ai = a_pair.imag, br = b_pair.real, bi = b_pair.imag;
|
|
210
|
+
sum_real_f64 += ar * br + ai * bi;
|
|
211
|
+
sum_imag_f64 += ar * bi - ai * br;
|
|
212
|
+
}
|
|
213
|
+
result->real = sum_real_f64;
|
|
214
|
+
result->imag = sum_imag_f64;
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
/**
|
|
218
|
+
* @brief Running state for 64-bit dot accumulation over f32 scalars on NEON.
|
|
219
|
+
*
|
|
220
|
+
* Processes 2 f32 values at a time, upcasting to f64 for accumulation to avoid
|
|
221
|
+
* catastrophic cancellation in long reductions.
|
|
222
|
+
*/
|
|
223
|
+
typedef struct nk_dot_f32x2_state_neon_t {
|
|
224
|
+
float64x2_t sum_f64x2;
|
|
225
|
+
} nk_dot_f32x2_state_neon_t;
|
|
226
|
+
|
|
227
|
+
/** @brief Resets both partial sums of an f32x2 dot-product state to zero. */
NK_INTERNAL void nk_dot_f32x2_init_neon(nk_dot_f32x2_state_neon_t *state) {
    float64x2_t const zeros_f64x2 = vdupq_n_f64(0);
    state->sum_f64x2 = zeros_f64x2;
}
|
|
228
|
+
|
|
229
|
+
/**
 *  @brief Accumulates the products of 2 f32 lanes from `a` and `b` into the f64 state.
 *
 *  Both lanes are always consumed; `depth_offset` and `active_dimensions` are not consulted.
 *  NOTE(review): presumably callers zero-fill unused lanes on a partial tail — TODO confirm.
 *
 *  @param state Running state to update.
 *  @param a 64-bit vector holding 2 f32 values (read through its `u32x2` view).
 *  @param b 64-bit vector holding 2 f32 values (read through its `u32x2` view).
 *  @param depth_offset Unused.
 *  @param active_dimensions Unused.
 */
NK_INTERNAL void nk_dot_f32x2_update_neon(nk_dot_f32x2_state_neon_t *state, nk_b64_vec_t a, nk_b64_vec_t b,
                                          nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // Reinterpret the raw bits as f32, then widen to f64 before the fused multiply-add.
    float64x2_t const a_wide_f64x2 = vcvt_f64_f32(vreinterpret_f32_u32(a.u32x2));
    float64x2_t const b_wide_f64x2 = vcvt_f64_f32(vreinterpret_f32_u32(b.u32x2));
    state->sum_f64x2 = vfmaq_f64(state->sum_f64x2, a_wide_f64x2, b_wide_f64x2);
}
|
|
240
|
+
|
|
241
|
+
NK_INTERNAL void nk_dot_f32x2_finalize_neon( //
|
|
242
|
+
nk_dot_f32x2_state_neon_t const *state_a, nk_dot_f32x2_state_neon_t const *state_b, //
|
|
243
|
+
nk_dot_f32x2_state_neon_t const *state_c, nk_dot_f32x2_state_neon_t const *state_d, //
|
|
244
|
+
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
245
|
+
nk_unused_(total_dimensions);
|
|
246
|
+
result->f64s[0] = vaddvq_f64(state_a->sum_f64x2);
|
|
247
|
+
result->f64s[1] = vaddvq_f64(state_b->sum_f64x2);
|
|
248
|
+
result->f64s[2] = vaddvq_f64(state_c->sum_f64x2);
|
|
249
|
+
result->f64s[3] = vaddvq_f64(state_d->sum_f64x2);
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
/**
 *  @brief Compensated f64 dot product using the Dot2 algorithm (Ogita-Rump-Oishi 2005).
 *
 *  Each step applies TwoProd (exact product error via FMA) and TwoSum (exact addition error) per lane.
 *  The errors are collected in a separate compensation vector and folded in by a compensated horizontal
 *  reduction at the end.
 *
 *  @param a_scalars First input, `count_scalars` elements.
 *  @param b_scalars Second input, `count_scalars` elements.
 *  @param count_scalars Number of elements in each input.
 *  @param result Output: the compensated dot product.
 */
NK_PUBLIC void nk_dot_f64_neon(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
                               nk_f64_t *result) {
    float64x2_t sum_f64x2 = vdupq_n_f64(0);
    float64x2_t compensation_f64x2 = vdupq_n_f64(0);

    // The body always runs at least once. When fewer than 2 elements remain (including an empty input),
    // a partial load supplies the final step; presumably it zero-fills the missing lanes — TODO confirm.
    do {
        float64x2_t a_f64x2, b_f64x2;
        if (count_scalars >= 2) {
            a_f64x2 = vld1q_f64(a_scalars);
            b_f64x2 = vld1q_f64(b_scalars);
            a_scalars += 2, b_scalars += 2, count_scalars -= 2;
        }
        else {
            nk_b128_vec_t a_tail, b_tail;
            nk_partial_load_b64x2_serial_(a_scalars, &a_tail, count_scalars);
            nk_partial_load_b64x2_serial_(b_scalars, &b_tail, count_scalars);
            a_f64x2 = a_tail.f64x2;
            b_f64x2 = b_tail.f64x2;
            count_scalars = 0;
        }

        // TwoProd: `product + product_error == a × b` exactly; the error is -(product - a × b) via one FMS.
        float64x2_t const product_f64x2 = vmulq_f64(a_f64x2, b_f64x2);
        float64x2_t const product_error_f64x2 = vnegq_f64(vfmsq_f64(product_f64x2, a_f64x2, b_f64x2));

        // TwoSum: `next_sum + sum_error == sum + product` exactly (barring overflow).
        float64x2_t const next_sum_f64x2 = vaddq_f64(sum_f64x2, product_f64x2);
        float64x2_t const addend_part_f64x2 = vsubq_f64(next_sum_f64x2, sum_f64x2);
        float64x2_t const sum_residue_f64x2 = vsubq_f64(sum_f64x2, vsubq_f64(next_sum_f64x2, addend_part_f64x2));
        float64x2_t const addend_residue_f64x2 = vsubq_f64(product_f64x2, addend_part_f64x2);
        float64x2_t const sum_error_f64x2 = vaddq_f64(sum_residue_f64x2, addend_residue_f64x2);

        sum_f64x2 = next_sum_f64x2;
        compensation_f64x2 = vaddq_f64(compensation_f64x2, vaddq_f64(sum_error_f64x2, product_error_f64x2));
    } while (count_scalars);

    // Horizontal reduction that keeps the Dot2 error terms.
    *result = nk_dot_stable_sum_f64x2_neon_(sum_f64x2, compensation_f64x2);
}
|
|
290
|
+
|
|
291
|
+
/**
 * @brief Compensated complex dot product Σ aₖ·bₖ (no conjugation) over `count_pairs` f64 complex numbers.
 *
 * Two complex pairs are consumed per iteration: `vld2q_f64` deinterleaves them so that one vector holds
 * two real parts and another holds two imaginary parts. Each of the four partial products is split with
 * TwoProd (the FMA yields its exact rounding error). It is then folded into the running real/imaginary
 * sums with TwoSum. Both error terms accumulate in per-lane compensation vectors, which are merged with
 * the sums only in the final horizontal reduction.
 *
 * @param a_pairs     First input vector of complex numbers.
 * @param b_pairs     Second input vector of complex numbers.
 * @param count_pairs Number of complex elements in each input.
 * @param result      Receives the real and imaginary parts of the dot product.
 */
NK_PUBLIC void nk_dot_f64c_neon(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                nk_f64c_t *result) {
    // Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated complex dot product
    float64x2_t sum_real_f64x2 = vdupq_n_f64(0);
    float64x2_t sum_imag_f64x2 = vdupq_n_f64(0);
    float64x2_t compensation_real_f64x2 = vdupq_n_f64(0);
    float64x2_t compensation_imag_f64x2 = vdupq_n_f64(0);
    float64x2_t a_real_f64x2, a_imag_f64x2, b_real_f64x2, b_imag_f64x2;

nk_dot_f64c_neon_cycle:
    if (count_pairs < 2) {
        // Tail: 0 or 1 remaining pair, i.e. 0 or 2 f64 values.
        // NOTE(review): relies on the partial load zero-filling unused lanes — confirm in the helper.
        nk_b128_vec_t a_tail, b_tail;
        nk_partial_load_b64x2_serial_(a_pairs, &a_tail, count_pairs * 2);
        nk_partial_load_b64x2_serial_(b_pairs, &b_tail, count_pairs * 2);
        float64x2_t zeros = vdupq_n_f64(0);
        // A lone pair [re, im] becomes real = [re, 0] and imag = [im, 0].
        a_real_f64x2 = vzip1q_f64(a_tail.f64x2, zeros);
        a_imag_f64x2 = vzip2q_f64(a_tail.f64x2, zeros);
        b_real_f64x2 = vzip1q_f64(b_tail.f64x2, zeros);
        b_imag_f64x2 = vzip2q_f64(b_tail.f64x2, zeros);
        count_pairs = 0;
    }
    else {
        // Deinterleave two pairs: val[0] = real parts, val[1] = imaginary parts.
        float64x2x2_t a_f64x2x2 = vld2q_f64((nk_f64_t const *)a_pairs);
        float64x2x2_t b_f64x2x2 = vld2q_f64((nk_f64_t const *)b_pairs);
        a_real_f64x2 = a_f64x2x2.val[0];
        a_imag_f64x2 = a_f64x2x2.val[1];
        b_real_f64x2 = b_f64x2x2.val[0];
        b_imag_f64x2 = b_f64x2x2.val[1];
        a_pairs += 2, b_pairs += 2, count_pairs -= 2;
    }

    // Real part: aᵣ × bᵣ - aᵢ × bᵢ (using TwoProd and TwoSum)
    // First term: +aᵣ × bᵣ
    float64x2_t product_rr_f64x2 = vmulq_f64(a_real_f64x2, b_real_f64x2);
    // TwoProd error: vfmsq computes product - a × b with a single rounding; negating gives a × b - product.
    float64x2_t error_rr_f64x2 = vnegq_f64(vfmsq_f64(product_rr_f64x2, a_real_f64x2, b_real_f64x2));
    // TwoSum: error_sum is the rounding error of sum + product.
    float64x2_t tentative_sum_real_f64x2 = vaddq_f64(sum_real_f64x2, product_rr_f64x2);
    float64x2_t virtual_addend_real_f64x2 = vsubq_f64(tentative_sum_real_f64x2, sum_real_f64x2);
    float64x2_t error_sum_real_f64x2 = vaddq_f64(
        vsubq_f64(sum_real_f64x2, vsubq_f64(tentative_sum_real_f64x2, virtual_addend_real_f64x2)),
        vsubq_f64(product_rr_f64x2, virtual_addend_real_f64x2));
    sum_real_f64x2 = tentative_sum_real_f64x2;
    compensation_real_f64x2 = vaddq_f64(compensation_real_f64x2, vaddq_f64(error_sum_real_f64x2, error_rr_f64x2));
    // Second term: -aᵢ × bᵢ (negate product and error, then standard TwoSum)
    float64x2_t product_ii_f64x2 = vmulq_f64(a_imag_f64x2, b_imag_f64x2);
    float64x2_t error_ii_f64x2 = vnegq_f64(vfmsq_f64(product_ii_f64x2, a_imag_f64x2, b_imag_f64x2));
    // Negation is exact, so negating both halves of the TwoProd result keeps the split exact.
    float64x2_t neg_product_ii_f64x2 = vnegq_f64(product_ii_f64x2);
    float64x2_t neg_error_ii_f64x2 = vnegq_f64(error_ii_f64x2);
    tentative_sum_real_f64x2 = vaddq_f64(sum_real_f64x2, neg_product_ii_f64x2);
    virtual_addend_real_f64x2 = vsubq_f64(tentative_sum_real_f64x2, sum_real_f64x2);
    error_sum_real_f64x2 = vaddq_f64(
        vsubq_f64(sum_real_f64x2, vsubq_f64(tentative_sum_real_f64x2, virtual_addend_real_f64x2)),
        vsubq_f64(neg_product_ii_f64x2, virtual_addend_real_f64x2));
    sum_real_f64x2 = tentative_sum_real_f64x2;
    compensation_real_f64x2 = vaddq_f64(compensation_real_f64x2, vaddq_f64(error_sum_real_f64x2, neg_error_ii_f64x2));

    // Imag part: aᵣ × bᵢ + aᵢ × bᵣ (using TwoProd and TwoSum)
    // First term: +aᵣ × bᵢ
    float64x2_t product_ri_f64x2 = vmulq_f64(a_real_f64x2, b_imag_f64x2);
    float64x2_t error_ri_f64x2 = vnegq_f64(vfmsq_f64(product_ri_f64x2, a_real_f64x2, b_imag_f64x2));
    float64x2_t tentative_sum_imag_f64x2 = vaddq_f64(sum_imag_f64x2, product_ri_f64x2);
    float64x2_t virtual_addend_imag_f64x2 = vsubq_f64(tentative_sum_imag_f64x2, sum_imag_f64x2);
    float64x2_t error_sum_imag_f64x2 = vaddq_f64(
        vsubq_f64(sum_imag_f64x2, vsubq_f64(tentative_sum_imag_f64x2, virtual_addend_imag_f64x2)),
        vsubq_f64(product_ri_f64x2, virtual_addend_imag_f64x2));
    sum_imag_f64x2 = tentative_sum_imag_f64x2;
    compensation_imag_f64x2 = vaddq_f64(compensation_imag_f64x2, vaddq_f64(error_sum_imag_f64x2, error_ri_f64x2));
    // Second term: +aᵢ × bᵣ
    float64x2_t product_ir_f64x2 = vmulq_f64(a_imag_f64x2, b_real_f64x2);
    float64x2_t error_ir_f64x2 = vnegq_f64(vfmsq_f64(product_ir_f64x2, a_imag_f64x2, b_real_f64x2));
    tentative_sum_imag_f64x2 = vaddq_f64(sum_imag_f64x2, product_ir_f64x2);
    virtual_addend_imag_f64x2 = vsubq_f64(tentative_sum_imag_f64x2, sum_imag_f64x2);
    error_sum_imag_f64x2 = vaddq_f64(
        vsubq_f64(sum_imag_f64x2, vsubq_f64(tentative_sum_imag_f64x2, virtual_addend_imag_f64x2)),
        vsubq_f64(product_ir_f64x2, virtual_addend_imag_f64x2));
    sum_imag_f64x2 = tentative_sum_imag_f64x2;
    compensation_imag_f64x2 = vaddq_f64(compensation_imag_f64x2, vaddq_f64(error_sum_imag_f64x2, error_ir_f64x2));

    if (count_pairs) goto nk_dot_f64c_neon_cycle;
    // Compensated horizontal reduction preserving Dot2 error tracking
    result->real = nk_dot_stable_sum_f64x2_neon_(sum_real_f64x2, compensation_real_f64x2);
    result->imag = nk_dot_stable_sum_f64x2_neon_(sum_imag_f64x2, compensation_imag_f64x2);
}
|
|
373
|
+
|
|
374
|
+
/**
 * @brief Compensated conjugate complex dot product Σ conj(aₖ)·bₖ over `count_pairs` f64 complex numbers.
 *
 * This has the same structure as `nk_dot_f64c_neon`, except that `a` is conjugated:
 * real = Σ (aᵣbᵣ + aᵢbᵢ) and imag = Σ (aᵣbᵢ − aᵢbᵣ). Every partial product is split with TwoProd and
 * folded in with TwoSum. The rounding errors accumulate in per-lane compensation vectors.
 *
 * @param a_pairs     Vector that is conjugated.
 * @param b_pairs     Second input vector of complex numbers.
 * @param count_pairs Number of complex elements in each input.
 * @param result      Receives the real and imaginary parts of the conjugate dot product.
 */
NK_PUBLIC void nk_vdot_f64c_neon(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                 nk_f64c_t *result) {
    // Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated conjugate dot product
    float64x2_t sum_real_f64x2 = vdupq_n_f64(0);
    float64x2_t sum_imag_f64x2 = vdupq_n_f64(0);
    float64x2_t compensation_real_f64x2 = vdupq_n_f64(0);
    float64x2_t compensation_imag_f64x2 = vdupq_n_f64(0);
    float64x2_t a_real_f64x2, a_imag_f64x2, b_real_f64x2, b_imag_f64x2;

nk_vdot_f64c_neon_cycle:
    if (count_pairs < 2) {
        // Tail: 0 or 1 remaining pair, i.e. 0 or 2 f64 values.
        // NOTE(review): relies on the partial load zero-filling unused lanes — confirm in the helper.
        nk_b128_vec_t a_tail, b_tail;
        nk_partial_load_b64x2_serial_(a_pairs, &a_tail, count_pairs * 2);
        nk_partial_load_b64x2_serial_(b_pairs, &b_tail, count_pairs * 2);
        float64x2_t zeros = vdupq_n_f64(0);
        // A lone pair [re, im] becomes real = [re, 0] and imag = [im, 0].
        a_real_f64x2 = vzip1q_f64(a_tail.f64x2, zeros);
        a_imag_f64x2 = vzip2q_f64(a_tail.f64x2, zeros);
        b_real_f64x2 = vzip1q_f64(b_tail.f64x2, zeros);
        b_imag_f64x2 = vzip2q_f64(b_tail.f64x2, zeros);
        count_pairs = 0;
    }
    else {
        // Deinterleave two pairs: val[0] = real parts, val[1] = imaginary parts.
        float64x2x2_t a_f64x2x2 = vld2q_f64((nk_f64_t const *)a_pairs);
        float64x2x2_t b_f64x2x2 = vld2q_f64((nk_f64_t const *)b_pairs);
        a_real_f64x2 = a_f64x2x2.val[0];
        a_imag_f64x2 = a_f64x2x2.val[1];
        b_real_f64x2 = b_f64x2x2.val[0];
        b_imag_f64x2 = b_f64x2x2.val[1];
        a_pairs += 2, b_pairs += 2, count_pairs -= 2;
    }

    // Real part: aᵣ × bᵣ + aᵢ × bᵢ (using TwoProd and TwoSum)
    // First term: +aᵣ × bᵣ
    float64x2_t product_rr_f64x2 = vmulq_f64(a_real_f64x2, b_real_f64x2);
    // TwoProd error: vfmsq computes product - a × b with a single rounding; negating gives a × b - product.
    float64x2_t error_rr_f64x2 = vnegq_f64(vfmsq_f64(product_rr_f64x2, a_real_f64x2, b_real_f64x2));
    // TwoSum: error_sum is the rounding error of sum + product.
    float64x2_t tentative_sum_real_f64x2 = vaddq_f64(sum_real_f64x2, product_rr_f64x2);
    float64x2_t virtual_addend_real_f64x2 = vsubq_f64(tentative_sum_real_f64x2, sum_real_f64x2);
    float64x2_t error_sum_real_f64x2 = vaddq_f64(
        vsubq_f64(sum_real_f64x2, vsubq_f64(tentative_sum_real_f64x2, virtual_addend_real_f64x2)),
        vsubq_f64(product_rr_f64x2, virtual_addend_real_f64x2));
    sum_real_f64x2 = tentative_sum_real_f64x2;
    compensation_real_f64x2 = vaddq_f64(compensation_real_f64x2, vaddq_f64(error_sum_real_f64x2, error_rr_f64x2));
    // Second term: +aᵢ × bᵢ (conjugate: add instead of subtract)
    float64x2_t product_ii_f64x2 = vmulq_f64(a_imag_f64x2, b_imag_f64x2);
    float64x2_t error_ii_f64x2 = vnegq_f64(vfmsq_f64(product_ii_f64x2, a_imag_f64x2, b_imag_f64x2));
    tentative_sum_real_f64x2 = vaddq_f64(sum_real_f64x2, product_ii_f64x2);
    virtual_addend_real_f64x2 = vsubq_f64(tentative_sum_real_f64x2, sum_real_f64x2);
    error_sum_real_f64x2 = vaddq_f64(
        vsubq_f64(sum_real_f64x2, vsubq_f64(tentative_sum_real_f64x2, virtual_addend_real_f64x2)),
        vsubq_f64(product_ii_f64x2, virtual_addend_real_f64x2));
    sum_real_f64x2 = tentative_sum_real_f64x2;
    compensation_real_f64x2 = vaddq_f64(compensation_real_f64x2, vaddq_f64(error_sum_real_f64x2, error_ii_f64x2));

    // Imag part: aᵣ × bᵢ - aᵢ × bᵣ (using TwoProd and TwoSum)
    // First term: +aᵣ × bᵢ
    float64x2_t product_ri_f64x2 = vmulq_f64(a_real_f64x2, b_imag_f64x2);
    float64x2_t error_ri_f64x2 = vnegq_f64(vfmsq_f64(product_ri_f64x2, a_real_f64x2, b_imag_f64x2));
    float64x2_t tentative_sum_imag_f64x2 = vaddq_f64(sum_imag_f64x2, product_ri_f64x2);
    float64x2_t virtual_addend_imag_f64x2 = vsubq_f64(tentative_sum_imag_f64x2, sum_imag_f64x2);
    float64x2_t error_sum_imag_f64x2 = vaddq_f64(
        vsubq_f64(sum_imag_f64x2, vsubq_f64(tentative_sum_imag_f64x2, virtual_addend_imag_f64x2)),
        vsubq_f64(product_ri_f64x2, virtual_addend_imag_f64x2));
    sum_imag_f64x2 = tentative_sum_imag_f64x2;
    compensation_imag_f64x2 = vaddq_f64(compensation_imag_f64x2, vaddq_f64(error_sum_imag_f64x2, error_ri_f64x2));
    // Second term: -aᵢ × bᵣ (conjugate: negate product and error, then standard TwoSum)
    float64x2_t product_ir_f64x2 = vmulq_f64(a_imag_f64x2, b_real_f64x2);
    float64x2_t error_ir_f64x2 = vnegq_f64(vfmsq_f64(product_ir_f64x2, a_imag_f64x2, b_real_f64x2));
    // Negation is exact, so negating both halves of the TwoProd result keeps the split exact.
    float64x2_t neg_product_ir_f64x2 = vnegq_f64(product_ir_f64x2);
    float64x2_t neg_error_ir_f64x2 = vnegq_f64(error_ir_f64x2);
    tentative_sum_imag_f64x2 = vaddq_f64(sum_imag_f64x2, neg_product_ir_f64x2);
    virtual_addend_imag_f64x2 = vsubq_f64(tentative_sum_imag_f64x2, sum_imag_f64x2);
    error_sum_imag_f64x2 = vaddq_f64(
        vsubq_f64(sum_imag_f64x2, vsubq_f64(tentative_sum_imag_f64x2, virtual_addend_imag_f64x2)),
        vsubq_f64(neg_product_ir_f64x2, virtual_addend_imag_f64x2));
    sum_imag_f64x2 = tentative_sum_imag_f64x2;
    compensation_imag_f64x2 = vaddq_f64(compensation_imag_f64x2, vaddq_f64(error_sum_imag_f64x2, neg_error_ir_f64x2));

    if (count_pairs) goto nk_vdot_f64c_neon_cycle;
    // Compensated horizontal reduction preserving Dot2 error tracking
    result->real = nk_dot_stable_sum_f64x2_neon_(sum_real_f64x2, compensation_real_f64x2);
    result->imag = nk_dot_stable_sum_f64x2_neon_(sum_imag_f64x2, compensation_imag_f64x2);
}
|
|
456
|
+
|
|
457
|
+
/**
 * @brief Running state for 128-bit dot accumulation over f64 scalars on NEON.
 *
 * Uses the Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated dot product.
 * Each of the two lanes keeps its own rounded sum and its own error term. The two are
 * combined only at reduction time, in `nk_dot_f64x2_finalize_neon`.
 */
typedef struct nk_dot_f64x2_state_neon_t {
    float64x2_t sum_f64x2;          // per-lane rounded running sums
    float64x2_t compensation_f64x2; // per-lane accumulated TwoProd + TwoSum rounding errors
} nk_dot_f64x2_state_neon_t;
|
|
466
|
+
|
|
467
|
+
/** @brief Resets a Dot2 f64 state: both the running sums and the error terms start at zero. */
NK_INTERNAL void nk_dot_f64x2_init_neon(nk_dot_f64x2_state_neon_t *state) {
    float64x2_t const zero_f64x2 = vdupq_n_f64(0);
    state->sum_f64x2 = zero_f64x2;
    state->compensation_f64x2 = zero_f64x2;
}
|
|
471
|
+
|
|
472
|
+
/**
 * @brief Folds two f64 products into the Dot2 state.
 *
 * TwoProd splits each product into a rounded high part and its exact error. TwoSum adds the
 * high part to the running sum and recovers that addition's error. Both errors go into the
 * compensation vector. `depth_offset` and `active_dimensions` are unused here.
 */
NK_INTERNAL void nk_dot_f64x2_update_neon(nk_dot_f64x2_state_neon_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                          nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    float64x2_t const lhs_f64x2 = vreinterpretq_f64_u64(a.u64x2);
    float64x2_t const rhs_f64x2 = vreinterpretq_f64_u64(b.u64x2);
    float64x2_t const previous_sum_f64x2 = state->sum_f64x2;
    float64x2_t const previous_error_f64x2 = state->compensation_f64x2;

    // TwoProd: high = rounded a × b; low = a × b - high, exact thanks to the fused multiply-subtract.
    float64x2_t const high_f64x2 = vmulq_f64(lhs_f64x2, rhs_f64x2);
    float64x2_t const low_f64x2 = vnegq_f64(vfmsq_f64(high_f64x2, lhs_f64x2, rhs_f64x2));

    // TwoSum (Knuth): next = previous + high, and add_error is the rounding error of that addition.
    float64x2_t const next_sum_f64x2 = vaddq_f64(previous_sum_f64x2, high_f64x2);
    float64x2_t const high_share_f64x2 = vsubq_f64(next_sum_f64x2, previous_sum_f64x2);
    float64x2_t const previous_share_f64x2 = vsubq_f64(next_sum_f64x2, high_share_f64x2);
    float64x2_t const add_error_f64x2 = vaddq_f64(vsubq_f64(previous_sum_f64x2, previous_share_f64x2),
                                                  vsubq_f64(high_f64x2, high_share_f64x2));

    state->sum_f64x2 = next_sum_f64x2;
    state->compensation_f64x2 = vaddq_f64(previous_error_f64x2, vaddq_f64(add_error_f64x2, low_f64x2));
}
|
|
495
|
+
|
|
496
|
+
NK_INTERNAL void nk_dot_f64x2_finalize_neon( //
|
|
497
|
+
nk_dot_f64x2_state_neon_t const *state_a, nk_dot_f64x2_state_neon_t const *state_b, //
|
|
498
|
+
nk_dot_f64x2_state_neon_t const *state_c, nk_dot_f64x2_state_neon_t const *state_d, //
|
|
499
|
+
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
500
|
+
nk_unused_(total_dimensions);
|
|
501
|
+
// Compensated horizontal reduction preserving Dot2 error tracking per state
|
|
502
|
+
result->f64s[0] = nk_dot_stable_sum_f64x2_neon_(state_a->sum_f64x2, state_a->compensation_f64x2);
|
|
503
|
+
result->f64s[1] = nk_dot_stable_sum_f64x2_neon_(state_b->sum_f64x2, state_b->compensation_f64x2);
|
|
504
|
+
result->f64s[2] = nk_dot_stable_sum_f64x2_neon_(state_c->sum_f64x2, state_c->compensation_f64x2);
|
|
505
|
+
result->f64s[3] = nk_dot_stable_sum_f64x2_neon_(state_d->sum_f64x2, state_d->compensation_f64x2);
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
#pragma endregion - Traditional Floats
|
|
509
|
+
|
|
510
|
+
#pragma region - Smaller Floats
|
|
511
|
+
|
|
512
|
+
NK_PUBLIC void nk_dot_bf16_neon(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
513
|
+
nk_f32_t *result) {
|
|
514
|
+
uint16x8_t a_u16x8, b_u16x8;
|
|
515
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
516
|
+
nk_dot_bf16_neon_cycle:
|
|
517
|
+
if (count_scalars < 8) {
|
|
518
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
519
|
+
nk_partial_load_b16x8_serial_(a_scalars, &a_vec, count_scalars);
|
|
520
|
+
nk_partial_load_b16x8_serial_(b_scalars, &b_vec, count_scalars);
|
|
521
|
+
a_u16x8 = a_vec.u16x8;
|
|
522
|
+
b_u16x8 = b_vec.u16x8;
|
|
523
|
+
count_scalars = 0;
|
|
524
|
+
}
|
|
525
|
+
else {
|
|
526
|
+
a_u16x8 = vld1q_u16((nk_u16_t const *)a_scalars);
|
|
527
|
+
b_u16x8 = vld1q_u16((nk_u16_t const *)b_scalars);
|
|
528
|
+
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
529
|
+
}
|
|
530
|
+
float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a_u16x8), 16));
|
|
531
|
+
float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(a_u16x8), 16));
|
|
532
|
+
float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b_u16x8), 16));
|
|
533
|
+
float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b_u16x8), 16));
|
|
534
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
|
|
535
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
|
|
536
|
+
if (count_scalars) goto nk_dot_bf16_neon_cycle;
|
|
537
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
538
|
+
}
|
|
539
|
+
|
|
540
|
+
/**
 * @brief Running state for 128-bit dot accumulation over bf16 scalars on plain NEON.
 *
 * Processes 8 bf16 values at a time (128 bits). They are converted to f32 with a widening
 * shift-left by 16 (`vshll_n_u16`), so accumulation does not require the ARMv8.6-BF16 extension.
 */
typedef struct nk_dot_bf16x8_state_neon_t {
    float32x4_t sum_f32x4; // four f32 partial sums, reduced horizontally in nk_dot_bf16x8_finalize_neon
} nk_dot_bf16x8_state_neon_t;
|
|
549
|
+
|
|
550
|
+
/** @brief Resets the bf16 running state to an all-zero f32 accumulator. */
NK_INTERNAL void nk_dot_bf16x8_init_neon(nk_dot_bf16x8_state_neon_t *state) {
    float32x4_t const zero_f32x4 = vdupq_n_f32(0);
    state->sum_f32x4 = zero_f32x4;
}
|
|
551
|
+
|
|
552
|
+
/**
 * @brief Accumulates the products of 8 bf16 pairs into the f32 running sum.
 *
 * The low four lanes are fused first, then the high four. `depth_offset` and
 * `active_dimensions` are unused here.
 */
NK_INTERNAL void nk_dot_bf16x8_update_neon(nk_dot_bf16x8_state_neon_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                           nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // bf16 -> f32: place each 16-bit pattern in the upper half of a 32-bit lane.
    float32x4_t const a_lo_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a.u16x8), 16));
    float32x4_t const b_lo_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.u16x8), 16));
    float32x4_t const a_hi_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(a.u16x8), 16));
    float32x4_t const b_hi_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.u16x8), 16));
    float32x4_t running_f32x4 = state->sum_f32x4;
    running_f32x4 = vfmaq_f32(running_f32x4, a_lo_f32x4, b_lo_f32x4);
    running_f32x4 = vfmaq_f32(running_f32x4, a_hi_f32x4, b_hi_f32x4);
    state->sum_f32x4 = running_f32x4;
}
|
|
564
|
+
|
|
565
|
+
NK_INTERNAL void nk_dot_bf16x8_finalize_neon( //
|
|
566
|
+
nk_dot_bf16x8_state_neon_t const *state_a, nk_dot_bf16x8_state_neon_t const *state_b, //
|
|
567
|
+
nk_dot_bf16x8_state_neon_t const *state_c, nk_dot_bf16x8_state_neon_t const *state_d, //
|
|
568
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
569
|
+
nk_unused_(total_dimensions);
|
|
570
|
+
result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
|
|
571
|
+
result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
|
|
572
|
+
result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
|
|
573
|
+
result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
|
|
574
|
+
}
|
|
575
|
+
|
|
576
|
+
NK_PUBLIC void nk_dot_f16_neon(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
577
|
+
nk_f32_t *result) {
|
|
578
|
+
uint16x8_t a_u16x8, b_u16x8;
|
|
579
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
580
|
+
nk_dot_f16_neon_cycle:
|
|
581
|
+
if (count_scalars < 8) {
|
|
582
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
583
|
+
nk_partial_load_b16x8_serial_(a_scalars, &a_vec, count_scalars);
|
|
584
|
+
nk_partial_load_b16x8_serial_(b_scalars, &b_vec, count_scalars);
|
|
585
|
+
a_u16x8 = a_vec.u16x8;
|
|
586
|
+
b_u16x8 = b_vec.u16x8;
|
|
587
|
+
count_scalars = 0;
|
|
588
|
+
}
|
|
589
|
+
else {
|
|
590
|
+
a_u16x8 = vld1q_u16((nk_u16_t const *)a_scalars);
|
|
591
|
+
b_u16x8 = vld1q_u16((nk_u16_t const *)b_scalars);
|
|
592
|
+
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
593
|
+
}
|
|
594
|
+
float32x4_t a_low_f32x4 = nk_f16x4_to_f32x4_neon_(vget_low_u16(a_u16x8));
|
|
595
|
+
float32x4_t a_high_f32x4 = nk_f16x4_to_f32x4_neon_(vget_high_u16(a_u16x8));
|
|
596
|
+
float32x4_t b_low_f32x4 = nk_f16x4_to_f32x4_neon_(vget_low_u16(b_u16x8));
|
|
597
|
+
float32x4_t b_high_f32x4 = nk_f16x4_to_f32x4_neon_(vget_high_u16(b_u16x8));
|
|
598
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
|
|
599
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
|
|
600
|
+
if (count_scalars) goto nk_dot_f16_neon_cycle;
|
|
601
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
602
|
+
}
|
|
603
|
+
|
|
604
|
+
/**
 * @brief Running state for 128-bit dot accumulation over f16 scalars on plain NEON.
 *
 * Processes 8 f16 values at a time (128 bits). They are converted to f32 via integer bit
 * manipulation (`nk_f16x4_to_f32x4_neon_`), so accumulation does not require the ARMv8.2-A FP16 extension.
 */
typedef struct nk_dot_f16x8_state_neon_t {
    float32x4_t sum_f32x4; // four f32 partial sums, reduced horizontally in nk_dot_f16x8_finalize_neon
} nk_dot_f16x8_state_neon_t;
|
|
613
|
+
|
|
614
|
+
/** @brief Resets the f16 running state to an all-zero f32 accumulator. */
NK_INTERNAL void nk_dot_f16x8_init_neon(nk_dot_f16x8_state_neon_t *state) {
    float32x4_t const zero_f32x4 = vdupq_n_f32(0);
    state->sum_f32x4 = zero_f32x4;
}
|
|
615
|
+
|
|
616
|
+
/**
 * @brief Accumulates the products of 8 f16 pairs into the f32 running sum.
 *
 * The low four lanes are fused first, then the high four. `depth_offset` and
 * `active_dimensions` are unused here.
 */
NK_INTERNAL void nk_dot_f16x8_update_neon(nk_dot_f16x8_state_neon_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                          nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // f16 -> f32 upcast through the integer-only helper, one 64-bit half at a time.
    float32x4_t const a_lo_f32x4 = nk_f16x4_to_f32x4_neon_(vget_low_u16(a.u16x8));
    float32x4_t const a_hi_f32x4 = nk_f16x4_to_f32x4_neon_(vget_high_u16(a.u16x8));
    float32x4_t const b_lo_f32x4 = nk_f16x4_to_f32x4_neon_(vget_low_u16(b.u16x8));
    float32x4_t const b_hi_f32x4 = nk_f16x4_to_f32x4_neon_(vget_high_u16(b.u16x8));
    float32x4_t running_f32x4 = state->sum_f32x4;
    running_f32x4 = vfmaq_f32(running_f32x4, a_lo_f32x4, b_lo_f32x4);
    running_f32x4 = vfmaq_f32(running_f32x4, a_hi_f32x4, b_hi_f32x4);
    state->sum_f32x4 = running_f32x4;
}
|
|
628
|
+
|
|
629
|
+
NK_INTERNAL void nk_dot_f16x8_finalize_neon( //
|
|
630
|
+
nk_dot_f16x8_state_neon_t const *state_a, nk_dot_f16x8_state_neon_t const *state_b, //
|
|
631
|
+
nk_dot_f16x8_state_neon_t const *state_c, nk_dot_f16x8_state_neon_t const *state_d, //
|
|
632
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
633
|
+
nk_unused_(total_dimensions);
|
|
634
|
+
result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
|
|
635
|
+
result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
|
|
636
|
+
result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
|
|
637
|
+
result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
|
|
638
|
+
}
|
|
639
|
+
|
|
640
|
+
NK_PUBLIC void nk_dot_e4m3_neon(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
641
|
+
nk_f32_t *result) {
|
|
642
|
+
float16x8_t a_f16x8, b_f16x8;
|
|
643
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
644
|
+
nk_dot_e4m3_neon_cycle:
|
|
645
|
+
if (count_scalars < 8) {
|
|
646
|
+
nk_b64_vec_t a_vec, b_vec;
|
|
647
|
+
nk_partial_load_b8x8_serial_(a_scalars, &a_vec, count_scalars);
|
|
648
|
+
nk_partial_load_b8x8_serial_(b_scalars, &b_vec, count_scalars);
|
|
649
|
+
a_f16x8 = nk_e4m3x8_to_f16x8_neon_(a_vec.u8x8);
|
|
650
|
+
b_f16x8 = nk_e4m3x8_to_f16x8_neon_(b_vec.u8x8);
|
|
651
|
+
count_scalars = 0;
|
|
652
|
+
}
|
|
653
|
+
else {
|
|
654
|
+
a_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a_scalars));
|
|
655
|
+
b_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b_scalars));
|
|
656
|
+
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
657
|
+
}
|
|
658
|
+
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
659
|
+
float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
|
|
660
|
+
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
661
|
+
float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
|
|
662
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
|
|
663
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
|
|
664
|
+
if (count_scalars) goto nk_dot_e4m3_neon_cycle;
|
|
665
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
666
|
+
}
|
|
667
|
+
|
|
668
|
+
NK_PUBLIC void nk_dot_e5m2_neon(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
669
|
+
nk_f32_t *result) {
|
|
670
|
+
float16x8_t a_f16x8, b_f16x8;
|
|
671
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
672
|
+
nk_dot_e5m2_neon_cycle:
|
|
673
|
+
if (count_scalars < 8) {
|
|
674
|
+
nk_b64_vec_t a_vec, b_vec;
|
|
675
|
+
nk_partial_load_b8x8_serial_(a_scalars, &a_vec, count_scalars);
|
|
676
|
+
nk_partial_load_b8x8_serial_(b_scalars, &b_vec, count_scalars);
|
|
677
|
+
a_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(a_vec.u8x8, 8));
|
|
678
|
+
b_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(b_vec.u8x8, 8));
|
|
679
|
+
count_scalars = 0;
|
|
680
|
+
}
|
|
681
|
+
else {
|
|
682
|
+
a_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vld1_u8(a_scalars), 8));
|
|
683
|
+
b_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vld1_u8(b_scalars), 8));
|
|
684
|
+
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
685
|
+
}
|
|
686
|
+
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
687
|
+
float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
|
|
688
|
+
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
689
|
+
float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
|
|
690
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
|
|
691
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
|
|
692
|
+
if (count_scalars) goto nk_dot_e5m2_neon_cycle;
|
|
693
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
694
|
+
}
|
|
695
|
+
|
|
696
|
+
NK_PUBLIC void nk_dot_e2m3_neon(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
697
|
+
nk_f32_t *result) {
|
|
698
|
+
float16x8_t a_low_f16x8, a_high_f16x8, b_low_f16x8, b_high_f16x8;
|
|
699
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
700
|
+
// x16 TBL path: process 16 elements per iteration via lookup table upcast
|
|
701
|
+
nk_dot_e2m3_neon_cycle:
|
|
702
|
+
if (count_scalars < 16) {
|
|
703
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
704
|
+
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
705
|
+
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
706
|
+
nk_e2m3x16_to_f16x8x2_neon_(a_vec.u8x16, &a_low_f16x8, &a_high_f16x8);
|
|
707
|
+
nk_e2m3x16_to_f16x8x2_neon_(b_vec.u8x16, &b_low_f16x8, &b_high_f16x8);
|
|
708
|
+
count_scalars = 0;
|
|
709
|
+
}
|
|
710
|
+
else {
|
|
711
|
+
nk_e2m3x16_to_f16x8x2_neon_(vld1q_u8(a_scalars), &a_low_f16x8, &a_high_f16x8);
|
|
712
|
+
nk_e2m3x16_to_f16x8x2_neon_(vld1q_u8(b_scalars), &b_low_f16x8, &b_high_f16x8);
|
|
713
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
714
|
+
}
|
|
715
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_low_f16x8)), vcvt_f32_f16(vget_low_f16(b_low_f16x8)));
|
|
716
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_high_f16(a_low_f16x8)),
|
|
717
|
+
vcvt_f32_f16(vget_high_f16(b_low_f16x8)));
|
|
718
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_high_f16x8)),
|
|
719
|
+
vcvt_f32_f16(vget_low_f16(b_high_f16x8)));
|
|
720
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_high_f16(a_high_f16x8)),
|
|
721
|
+
vcvt_f32_f16(vget_high_f16(b_high_f16x8)));
|
|
722
|
+
if (count_scalars) goto nk_dot_e2m3_neon_cycle;
|
|
723
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
724
|
+
}
|
|
725
|
+
|
|
726
|
+
NK_PUBLIC void nk_dot_e3m2_neon(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
|
|
727
|
+
nk_f32_t *result) {
|
|
728
|
+
float16x8_t a_low_f16x8, a_high_f16x8, b_low_f16x8, b_high_f16x8;
|
|
729
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
730
|
+
// x16 TBL path: process 16 elements per iteration via lookup table upcast
|
|
731
|
+
nk_dot_e3m2_neon_cycle:
|
|
732
|
+
if (count_scalars < 16) {
|
|
733
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
734
|
+
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
735
|
+
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
736
|
+
nk_e3m2x16_to_f16x8x2_neon_(a_vec.u8x16, &a_low_f16x8, &a_high_f16x8);
|
|
737
|
+
nk_e3m2x16_to_f16x8x2_neon_(b_vec.u8x16, &b_low_f16x8, &b_high_f16x8);
|
|
738
|
+
count_scalars = 0;
|
|
739
|
+
}
|
|
740
|
+
else {
|
|
741
|
+
nk_e3m2x16_to_f16x8x2_neon_(vld1q_u8(a_scalars), &a_low_f16x8, &a_high_f16x8);
|
|
742
|
+
nk_e3m2x16_to_f16x8x2_neon_(vld1q_u8(b_scalars), &b_low_f16x8, &b_high_f16x8);
|
|
743
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
744
|
+
}
|
|
745
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_low_f16x8)), vcvt_f32_f16(vget_low_f16(b_low_f16x8)));
|
|
746
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_high_f16(a_low_f16x8)),
|
|
747
|
+
vcvt_f32_f16(vget_high_f16(b_low_f16x8)));
|
|
748
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_high_f16x8)),
|
|
749
|
+
vcvt_f32_f16(vget_low_f16(b_high_f16x8)));
|
|
750
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_high_f16(a_high_f16x8)),
|
|
751
|
+
vcvt_f32_f16(vget_high_f16(b_high_f16x8)));
|
|
752
|
+
if (count_scalars) goto nk_dot_e3m2_neon_cycle;
|
|
753
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
754
|
+
}
|
|
755
|
+
|
|
756
|
+
#pragma endregion - Smaller Floats
|
|
757
|
+
|
|
758
|
+
#pragma region - Binary
|
|
759
|
+
|
|
760
|
+
/**
 * @brief Binary dot product: number of bit positions set in both `a` and `b`.
 *
 * Reads ceil(n_bits / 8) bytes from each input.
 * NOTE(review): all bits of a trailing partial byte are counted, so this assumes padding bits
 * beyond `n_bits` are zero — confirm with callers.
 */
NK_PUBLIC void nk_dot_u1_neon(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
    nk_size_t const n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
    nk_size_t offset = 0;
    nk_u32_t matches = 0;
    while (n_bytes - offset >= 16) {
        // Per-byte counters hold at most 8 per block; 31 blocks x 8 = 248 stays below the u8 limit.
        uint8x16_t counts_u8x16 = vdupq_n_u8(0);
        nk_size_t blocks_left = 31;
        do {
            uint8x16_t const common_u8x16 = vandq_u8(vld1q_u8(a + offset), vld1q_u8(b + offset));
            counts_u8x16 = vaddq_u8(counts_u8x16, vcntq_u8(common_u8x16));
            offset += 16;
        } while (--blocks_left && n_bytes - offset >= 16);
        matches += (nk_u32_t)vaddlvq_u8(counts_u8x16);
    }
    // Scalar tail for the last < 16 bytes.
    while (offset < n_bytes) {
        matches += nk_u1x8_popcount_(a[offset] & b[offset]);
        ++offset;
    }
    *result = matches;
}
|
|
776
|
+
|
|
777
|
+
/**
 * @brief Running state for 128-bit binary dot accumulation on NEON.
 *
 * Counts bits set in both operands (popcount of `a & b`). Each u32 lane holds a partial
 * count; `nk_dot_u1x128_finalize_neon` sums the four lanes of each state.
 */
typedef struct nk_dot_u1x128_state_neon_t {
    uint32x4_t dot_count_u32x4; // four partial set-bit counts
} nk_dot_u1x128_state_neon_t;
|
|
780
|
+
|
|
781
|
+
/** @brief Resets the binary dot state: all four partial counts start at zero. */
NK_INTERNAL void nk_dot_u1x128_init_neon(nk_dot_u1x128_state_neon_t *state) {
    uint32x4_t const zero_u32x4 = vdupq_n_u32(0);
    state->dot_count_u32x4 = zero_u32x4;
}
|
|
782
|
+
|
|
783
|
+
/**
 * @brief Adds the popcount of `a & b` (128 bits) into the four u32 partial counts.
 *
 * `depth_offset` and `active_dimensions` are unused here.
 */
NK_INTERNAL void nk_dot_u1x128_update_neon(nk_dot_u1x128_state_neon_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                           nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    uint8x16_t const byte_counts_u8x16 = vcntq_u8(vandq_u8(a.u8x16, b.u8x16));
    // Widen u8 -> u16 pairwise, then pairwise-add-and-accumulate u16 -> u32 into the state.
    state->dot_count_u32x4 = vpadalq_u16(state->dot_count_u32x4, vpaddlq_u8(byte_counts_u8x16));
}
|
|
793
|
+
|
|
794
|
+
NK_INTERNAL void nk_dot_u1x128_finalize_neon( //
|
|
795
|
+
nk_dot_u1x128_state_neon_t const *state_a, nk_dot_u1x128_state_neon_t const *state_b,
|
|
796
|
+
nk_dot_u1x128_state_neon_t const *state_c, nk_dot_u1x128_state_neon_t const *state_d, nk_size_t total_dimensions,
|
|
797
|
+
nk_b128_vec_t *result) {
|
|
798
|
+
nk_unused_(total_dimensions);
|
|
799
|
+
uint32x4_t ab_sum_u32x4 = vpaddq_u32(state_a->dot_count_u32x4, state_b->dot_count_u32x4);
|
|
800
|
+
uint32x4_t cd_sum_u32x4 = vpaddq_u32(state_c->dot_count_u32x4, state_d->dot_count_u32x4);
|
|
801
|
+
result->u32x4 = vpaddq_u32(ab_sum_u32x4, cd_sum_u32x4);
|
|
802
|
+
}
|
|
803
|
+
|
|
804
|
+
#pragma endregion - Binary
|
|
805
|
+
|
|
806
|
+
#if defined(__clang__)
|
|
807
|
+
#pragma clang attribute pop
|
|
808
|
+
#elif defined(__GNUC__)
|
|
809
|
+
#pragma GCC pop_options
|
|
810
|
+
#endif
|
|
811
|
+
|
|
812
|
+
#if defined(__cplusplus)
|
|
813
|
+
} // extern "C"
|
|
814
|
+
#endif
|
|
815
|
+
|
|
816
|
+
#endif // NK_TARGET_NEON
|
|
817
|
+
#endif // NK_TARGET_ARM_
|
|
818
|
+
#endif // NK_DOT_NEON_H
|