numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for NEON FP16.
|
|
3
|
+
* @file include/numkong/dot/neonhalf.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput
|
|
12
|
+
* A76 M4+/V1+/Oryon
|
|
13
|
+
* vfmaq_f16 FMLA (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
|
|
14
|
+
* vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
|
|
15
|
+
* vld1q_f16 LD1 (V.8H) 4cy 2/cy 3/cy
|
|
16
|
+
* vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
|
|
17
|
+
* vfmsq_f16 FMLS (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
|
|
18
|
+
*
|
|
19
|
+
* The ARMv8.2-FP16 extension enables native half-precision arithmetic, doubling the element count
|
|
20
|
+
* per vector register (8x F16 vs 4x F32). This doubles theoretical throughput for bandwidth-bound
|
|
21
|
+
* workloads while halving memory footprint.
|
|
22
|
+
*
|
|
23
|
+
* For dot products, inputs are widened from F16 to F32 for accumulation to preserve numerical
|
|
24
|
+
* precision. The FCVTL instruction handles this widening, allowing the FMA operations
|
|
25
|
+
* to maintain full F32 precision in the accumulator.
|
|
26
|
+
*
|
|
27
|
+
* @section dot_neonhalf_stateful Stateful Streaming Logic
|
|
28
|
+
*
|
|
29
|
+
* To build memory-optimal tiled algorithms, this file defines the following structures and force-inlined
|
|
30
|
+
* `NK_INTERNAL` functions:
|
|
31
|
+
*
|
|
32
|
+
* - nk_dot_f16x4 state with f16 inputs widened to f32 for accumulation.
|
|
33
|
+
*
|
|
34
|
+
* @code{c}
|
|
35
|
+
* nk_dot_f16x4_state_neonhalf_t state_first, state_second, state_third, state_fourth;
* nk_b64_vec_t query_vec, target_first_vec, target_second_vec, target_third_vec, target_fourth_vec;
* nk_dot_f16x4_init_neonhalf(&state_first);
* nk_dot_f16x4_init_neonhalf(&state_second);
* nk_dot_f16x4_init_neonhalf(&state_third);
* nk_dot_f16x4_init_neonhalf(&state_fourth);
* for (nk_size_t idx = 0; idx + 4 <= depth; idx += 4) {
*     query_vec.u16x4 = vreinterpret_u16_f16(vld1_f16(query_ptr + idx));
*     target_first_vec.u16x4 = vreinterpret_u16_f16(vld1_f16(target_first_ptr + idx));
*     target_second_vec.u16x4 = vreinterpret_u16_f16(vld1_f16(target_second_ptr + idx));
*     target_third_vec.u16x4 = vreinterpret_u16_f16(vld1_f16(target_third_ptr + idx));
*     target_fourth_vec.u16x4 = vreinterpret_u16_f16(vld1_f16(target_fourth_ptr + idx));
*     nk_dot_f16x4_update_neonhalf(&state_first, query_vec, target_first_vec, idx, 4);
*     nk_dot_f16x4_update_neonhalf(&state_second, query_vec, target_second_vec, idx, 4);
*     nk_dot_f16x4_update_neonhalf(&state_third, query_vec, target_third_vec, idx, 4);
*     nk_dot_f16x4_update_neonhalf(&state_fourth, query_vec, target_fourth_vec, idx, 4);
* }
* nk_b128_vec_t results;
* nk_dot_f16x4_finalize_neonhalf(&state_first, &state_second, &state_third, &state_fourth, depth, &results);
|
|
54
|
+
* @endcode
|
|
55
|
+
*/
|
|
56
|
+
#ifndef NK_DOT_NEONHALF_H
|
|
57
|
+
#define NK_DOT_NEONHALF_H
|
|
58
|
+
|
|
59
|
+
#if NK_TARGET_ARM_
|
|
60
|
+
#if NK_TARGET_NEONHALF
|
|
61
|
+
|
|
62
|
+
#include "numkong/types.h"
|
|
63
|
+
#include "numkong/cast/serial.h" // `nk_partial_load_b16x4_serial_`
|
|
64
|
+
|
|
65
|
+
#if defined(__cplusplus)
|
|
66
|
+
extern "C" {
|
|
67
|
+
#endif
|
|
68
|
+
|
|
69
|
+
#if defined(__clang__)
|
|
70
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
|
|
71
|
+
#elif defined(__GNUC__)
|
|
72
|
+
#pragma GCC push_options
|
|
73
|
+
#pragma GCC target("arch=armv8.2-a+simd+fp16")
|
|
74
|
+
#endif
|
|
75
|
+
|
|
76
|
+
NK_PUBLIC void nk_dot_f16_neonhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
77
|
+
nk_f32_t *result) {
|
|
78
|
+
float32x4_t a_f32x4, b_f32x4;
|
|
79
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
80
|
+
nk_dot_f16_neonhalf_cycle:
|
|
81
|
+
if (count_scalars < 4) {
|
|
82
|
+
nk_b64_vec_t a_vec, b_vec;
|
|
83
|
+
nk_partial_load_b16x4_serial_(a_scalars, &a_vec, count_scalars);
|
|
84
|
+
nk_partial_load_b16x4_serial_(b_scalars, &b_vec, count_scalars);
|
|
85
|
+
a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
|
|
86
|
+
b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
|
|
87
|
+
count_scalars = 0;
|
|
88
|
+
}
|
|
89
|
+
else {
|
|
90
|
+
a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a_scalars));
|
|
91
|
+
b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b_scalars));
|
|
92
|
+
a_scalars += 4, b_scalars += 4, count_scalars -= 4;
|
|
93
|
+
}
|
|
94
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_f32x4, b_f32x4);
|
|
95
|
+
if (count_scalars) goto nk_dot_f16_neonhalf_cycle;
|
|
96
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
NK_PUBLIC void nk_dot_f16c_neonhalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
|
|
100
|
+
nk_f32c_t *result) {
|
|
101
|
+
float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
|
|
102
|
+
float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
|
|
103
|
+
while (count_pairs >= 4) {
|
|
104
|
+
// Unpack the input arrays into real and imaginary parts.
|
|
105
|
+
// MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed
|
|
106
|
+
// integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards.
|
|
107
|
+
int16x4x2_t a_i16x4x2 = vld2_s16((short *)a_pairs);
|
|
108
|
+
int16x4x2_t b_i16x4x2 = vld2_s16((short *)b_pairs);
|
|
109
|
+
float32x4_t a_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[0]));
|
|
110
|
+
float32x4_t a_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[1]));
|
|
111
|
+
float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
|
|
112
|
+
float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
|
|
113
|
+
sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
|
|
114
|
+
sum_real_f32x4 = vfmsq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
|
|
115
|
+
sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
|
|
116
|
+
sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
|
|
117
|
+
count_pairs -= 4, a_pairs += 4, b_pairs += 4;
|
|
118
|
+
}
|
|
119
|
+
// Reduce horizontal sums and aggregate with the tail:
|
|
120
|
+
nk_f32c_t tail_result;
|
|
121
|
+
nk_dot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
|
|
122
|
+
result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
|
|
123
|
+
result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
NK_PUBLIC void nk_vdot_f16c_neonhalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
|
|
127
|
+
nk_f32c_t *result) {
|
|
128
|
+
float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
|
|
129
|
+
float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
|
|
130
|
+
while (count_pairs >= 4) {
|
|
131
|
+
// Unpack the input arrays into real and imaginary parts.
|
|
132
|
+
// MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed
|
|
133
|
+
// integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards.
|
|
134
|
+
int16x4x2_t a_i16x4x2 = vld2_s16((short *)a_pairs);
|
|
135
|
+
int16x4x2_t b_i16x4x2 = vld2_s16((short *)b_pairs);
|
|
136
|
+
float32x4_t a_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[0]));
|
|
137
|
+
float32x4_t a_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[1]));
|
|
138
|
+
float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
|
|
139
|
+
float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
|
|
140
|
+
sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
|
|
141
|
+
sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
|
|
142
|
+
sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
|
|
143
|
+
sum_imag_f32x4 = vfmsq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
|
|
144
|
+
count_pairs -= 4, a_pairs += 4, b_pairs += 4;
|
|
145
|
+
}
|
|
146
|
+
// Reduce horizontal sums and aggregate with the tail:
|
|
147
|
+
nk_f32c_t tail_result;
|
|
148
|
+
nk_vdot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
|
|
149
|
+
result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
|
|
150
|
+
result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
/**
|
|
154
|
+
* @brief Running state for 64-bit dot accumulation over f16 scalars on NEON with FP16 extension.
|
|
155
|
+
*
|
|
156
|
+
* Processes 4 f16 values at a time (64 bits), converting directly to f32 without
|
|
157
|
+
* the overhead of vget_low/vget_high operations on 128-bit vectors.
|
|
158
|
+
*/
|
|
159
|
+
typedef struct nk_dot_f16x4_state_neonhalf_t {
|
|
160
|
+
float32x4_t sum_f32x4;
|
|
161
|
+
} nk_dot_f16x4_state_neonhalf_t;
|
|
162
|
+
|
|
163
|
+
/** @brief Resets a streaming f16 dot-product state so that all four f32 accumulator lanes are zero. */
NK_INTERNAL void nk_dot_f16x4_init_neonhalf(nk_dot_f16x4_state_neonhalf_t *state) {
    float32x4_t const zero_f32x4 = vdupq_n_f32(0.0f);
    state->sum_f32x4 = zero_f32x4;
}
|
|
164
|
+
|
|
165
|
+
/**
 * @brief Folds one 4-lane f16 chunk of each operand into the running f32 accumulator.
 *
 * Both 64-bit inputs are reinterpreted as f16x4, widened to f32x4, and fused-multiply-added into
 * `state->sum_f32x4`. `depth_offset` and `active_dimensions` are accepted for interface
 * uniformity but ignored here.
 */
NK_INTERNAL void nk_dot_f16x4_update_neonhalf(nk_dot_f16x4_state_neonhalf_t *state, nk_b64_vec_t a, nk_b64_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // A 64-bit register holds exactly four f16 lanes, so widen directly without a low/high split.
    float32x4_t const a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a.u16x4));
    float32x4_t const b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b.u16x4));
    state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, a_f32x4, b_f32x4);
}
|
|
174
|
+
|
|
175
|
+
NK_INTERNAL void nk_dot_f16x4_finalize_neonhalf( //
|
|
176
|
+
nk_dot_f16x4_state_neonhalf_t const *state_a, nk_dot_f16x4_state_neonhalf_t const *state_b, //
|
|
177
|
+
nk_dot_f16x4_state_neonhalf_t const *state_c, nk_dot_f16x4_state_neonhalf_t const *state_d, //
|
|
178
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
179
|
+
nk_unused_(total_dimensions);
|
|
180
|
+
result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
|
|
181
|
+
result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
|
|
182
|
+
result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
|
|
183
|
+
result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
#if defined(__clang__)
|
|
187
|
+
#pragma clang attribute pop
|
|
188
|
+
#elif defined(__GNUC__)
|
|
189
|
+
#pragma GCC pop_options
|
|
190
|
+
#endif
|
|
191
|
+
|
|
192
|
+
#if defined(__cplusplus)
|
|
193
|
+
} // extern "C"
|
|
194
|
+
#endif
|
|
195
|
+
|
|
196
|
+
#endif // NK_TARGET_NEONHALF
|
|
197
|
+
#endif // NK_TARGET_ARM_
|
|
198
|
+
#endif // NK_DOT_NEONHALF_H
|