numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,508 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for NEON SDOT.
|
|
3
|
+
* @file include/numkong/dot/neonsdot.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_neonsdot_instructions ARM NEON SDOT/UDOT Instructions (ARMv8.4-DotProd)
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput
|
|
12
|
+
* A76 M4+/V1+/Oryon
|
|
13
|
+
* vdotq_s32 SDOT (V.4S, V.16B, V.16B) 3cy 2/cy 4/cy
|
|
14
|
+
* vdotq_u32 UDOT (V.4S, V.16B, V.16B) 3cy 2/cy 4/cy
|
|
15
|
+
* vld1q_s8 LD1 (V.16B) 4cy 2/cy 3/cy
|
|
16
|
+
* vld1q_u8 LD1 (V.16B) 4cy 2/cy 3/cy
|
|
17
|
+
* vaddvq_s32 ADDV (V.4S) 4cy 1/cy 2/cy
|
|
18
|
+
* vaddvq_u32 ADDV (V.4S) 4cy 1/cy 2/cy
|
|
19
|
+
*
|
|
20
|
+
* The ARMv8.4-DotProd extension provides SDOT/UDOT instructions critical for int8 quantized ML
|
|
21
|
+
* inference. Each instruction computes four dot products of 4-element int8 vectors, accumulating
|
|
22
|
+
* into int32 lanes, processing 16 multiply-accumulates per instruction.
|
|
23
|
+
*
|
|
24
|
+
* SDOT handles signed int8 operands while UDOT handles unsigned. The 3-cycle latency with 2/cy
|
|
25
|
+
* throughput on A76 (4/cy on newer cores) enables int8 matrix multiplication for
|
|
26
|
+
* quantized neural network inference, where 8-bit weights reduce memory bandwidth by 4x vs FP32.
|
|
27
|
+
*
|
|
28
|
+
* @section dot_neonsdot_stateful Stateful Streaming Logic
|
|
29
|
+
*
|
|
30
|
+
* To build memory-optimal tiled algorithms, this file defines the following structures and force-inlined
|
|
31
|
+
* `NK_INTERNAL` functions:
|
|
32
|
+
*
|
|
33
|
+
* - nk_dot_i8x16 for 8-bit signed integer inputs using SDOT,
|
|
34
|
+
* - nk_dot_u8x16 for 8-bit unsigned integer inputs using UDOT,
|
|
35
|
+
* - nk_dot_i4x32 for 4-bit signed integer products,
|
|
36
|
+
* - nk_dot_u4x32 for 4-bit unsigned integer products.
|
|
37
|
+
*
|
|
38
|
+
* @code{c}
|
|
39
|
+
* nk_dot_i8x16_state_neonsdot_t state_first, state_second, state_third, state_fourth;
* nk_b128_vec_t query_vec, target_first_vec, target_second_vec, target_third_vec, target_fourth_vec;
* nk_dot_i8x16_init_neonsdot(&state_first);
* nk_dot_i8x16_init_neonsdot(&state_second);
* nk_dot_i8x16_init_neonsdot(&state_third);
* nk_dot_i8x16_init_neonsdot(&state_fourth);
* for (nk_size_t idx = 0; idx + 16 <= depth; idx += 16) {
*     query_vec.u8x16 = vld1q_u8((nk_u8_t const *)(query_ptr + idx));
*     target_first_vec.u8x16 = vld1q_u8((nk_u8_t const *)(target_first_ptr + idx));
*     target_second_vec.u8x16 = vld1q_u8((nk_u8_t const *)(target_second_ptr + idx));
*     target_third_vec.u8x16 = vld1q_u8((nk_u8_t const *)(target_third_ptr + idx));
*     target_fourth_vec.u8x16 = vld1q_u8((nk_u8_t const *)(target_fourth_ptr + idx));
*     nk_dot_i8x16_update_neonsdot(&state_first, query_vec, target_first_vec, idx, 16);
*     nk_dot_i8x16_update_neonsdot(&state_second, query_vec, target_second_vec, idx, 16);
*     nk_dot_i8x16_update_neonsdot(&state_third, query_vec, target_third_vec, idx, 16);
*     nk_dot_i8x16_update_neonsdot(&state_fourth, query_vec, target_fourth_vec, idx, 16);
* }
* nk_b128_vec_t results_vec; // results_vec.i32s[0..3] receive the four dot products
* nk_dot_i8x16_finalize_neonsdot(&state_first, &state_second, &state_third, &state_fourth, depth, &results_vec);
|
|
58
|
+
* @endcode
|
|
59
|
+
*
|
|
60
|
+
* For 4-bit integers, the state manages unpacking and accumulation:
|
|
61
|
+
*
|
|
62
|
+
* @code{c}
|
|
63
|
+
* nk_dot_i4x32_state_neonsdot_t state_first, state_second, state_third, state_fourth;
* nk_b128_vec_t query_vec, target_first_vec, target_second_vec, target_third_vec, target_fourth_vec;
* nk_dot_i4x32_init_neonsdot(&state_first);
* nk_dot_i4x32_init_neonsdot(&state_second);
* nk_dot_i4x32_init_neonsdot(&state_third);
* nk_dot_i4x32_init_neonsdot(&state_fourth);
* for (nk_size_t idx = 0; idx + 32 <= depth; idx += 32) { // 32 nibbles = 16 packed bytes per step
*     query_vec.u8x16 = vld1q_u8((nk_u8_t const *)query_ptr + idx / 2);
*     target_first_vec.u8x16 = vld1q_u8((nk_u8_t const *)target_first_ptr + idx / 2);
*     target_second_vec.u8x16 = vld1q_u8((nk_u8_t const *)target_second_ptr + idx / 2);
*     target_third_vec.u8x16 = vld1q_u8((nk_u8_t const *)target_third_ptr + idx / 2);
*     target_fourth_vec.u8x16 = vld1q_u8((nk_u8_t const *)target_fourth_ptr + idx / 2);
*     nk_dot_i4x32_update_neonsdot(&state_first, query_vec, target_first_vec, idx, 32);
*     nk_dot_i4x32_update_neonsdot(&state_second, query_vec, target_second_vec, idx, 32);
*     nk_dot_i4x32_update_neonsdot(&state_third, query_vec, target_third_vec, idx, 32);
*     nk_dot_i4x32_update_neonsdot(&state_fourth, query_vec, target_fourth_vec, idx, 32);
* }
* nk_b128_vec_t results_vec; // results_vec.i32s[0..3] receive the four dot products
* nk_dot_i4x32_finalize_neonsdot(&state_first, &state_second, &state_third, &state_fourth, depth, &results_vec);
|
|
82
|
+
* @endcode
|
|
83
|
+
*/
|
|
84
|
+
#ifndef NK_DOT_NEONSDOT_H
|
|
85
|
+
#define NK_DOT_NEONSDOT_H
|
|
86
|
+
|
|
87
|
+
#if NK_TARGET_ARM_
|
|
88
|
+
#if NK_TARGET_NEONSDOT
|
|
89
|
+
|
|
90
|
+
#include "numkong/types.h"
|
|
91
|
+
|
|
92
|
+
#if defined(__cplusplus)
|
|
93
|
+
extern "C" {
|
|
94
|
+
#endif
|
|
95
|
+
|
|
96
|
+
#if defined(__clang__)
|
|
97
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+dotprod"))), apply_to = function)
|
|
98
|
+
#elif defined(__GNUC__)
|
|
99
|
+
#pragma GCC push_options
|
|
100
|
+
#pragma GCC target("arch=armv8.2-a+dotprod")
|
|
101
|
+
#endif
|
|
102
|
+
|
|
103
|
+
/**
 * @brief Dot product of two signed 8-bit vectors using the SDOT instruction.
 *
 * Full 16-element chunks are folded into four int32 lanes with SDOT. The lanes are then summed
 * horizontally, and any trailing `count_scalars % 16` elements are added one by one.
 *
 * @param a_scalars First operand, `count_scalars` elements.
 * @param b_scalars Second operand, `count_scalars` elements.
 * @param count_scalars Number of elements in each operand.
 * @param result Output: the 32-bit dot product.
 */
NK_PUBLIC void nk_dot_i8_neonsdot(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
                                  nk_i32_t *result) {
    nk_size_t const body_end = count_scalars - count_scalars % 16;
    int32x4_t acc_i32x4 = vdupq_n_s32(0);

    // Each SDOT adds 16 int8 products, four per int32 lane.
    for (nk_size_t offset = 0; offset != body_end; offset += 16) {
        int8x16_t const lhs_i8x16 = vld1q_s8(a_scalars + offset);
        int8x16_t const rhs_i8x16 = vld1q_s8(b_scalars + offset);
        acc_i32x4 = vdotq_s32(acc_i32x4, lhs_i8x16, rhs_i8x16);
    }

    nk_i32_t total = vaddvq_s32(acc_i32x4);
    // Scalar epilogue for the elements that do not fill a whole 16-byte chunk.
    for (nk_size_t offset = body_end; offset < count_scalars; ++offset) {
        total += (nk_i32_t)a_scalars[offset] * b_scalars[offset];
    }
    *result = total;
}
|
|
116
|
+
|
|
117
|
+
/**
 * @brief Dot product of two unsigned 8-bit vectors using the UDOT instruction.
 *
 * Full 16-element chunks are folded into four uint32 lanes with UDOT. The lanes are then summed
 * horizontally, and any trailing `count_scalars % 16` elements are added one by one.
 *
 * @param a_scalars First operand, `count_scalars` elements.
 * @param b_scalars Second operand, `count_scalars` elements.
 * @param count_scalars Number of elements in each operand.
 * @param result Output: the 32-bit unsigned dot product (wraps modulo 2^32).
 */
NK_PUBLIC void nk_dot_u8_neonsdot(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
                                  nk_u32_t *result) {
    nk_size_t const body_end = count_scalars - count_scalars % 16;
    uint32x4_t acc_u32x4 = vdupq_n_u32(0);

    // Each UDOT adds 16 uint8 products, four per uint32 lane.
    for (nk_size_t offset = 0; offset != body_end; offset += 16) {
        uint8x16_t const lhs_u8x16 = vld1q_u8(a_scalars + offset);
        uint8x16_t const rhs_u8x16 = vld1q_u8(b_scalars + offset);
        acc_u32x4 = vdotq_u32(acc_u32x4, lhs_u8x16, rhs_u8x16);
    }

    nk_u32_t total = vaddvq_u32(acc_u32x4);
    // Scalar epilogue for the elements that do not fill a whole 16-byte chunk.
    for (nk_size_t offset = body_end; offset < count_scalars; ++offset) {
        total += (nk_u32_t)a_scalars[offset] * b_scalars[offset];
    }
    *result = total;
}
|
|
130
|
+
|
|
131
|
+
/**
 * @brief Running state for 128-bit dot accumulation over i8 scalars on NEON.
 *
 * Holds four int32 partial sums produced by SDOT. They are reduced horizontally only in
 * `nk_dot_i8x16_finalize_neonsdot`.
 */
typedef struct nk_dot_i8x16_state_neonsdot_t nk_dot_i8x16_state_neonsdot_t;
struct nk_dot_i8x16_state_neonsdot_t {
    int32x4_t sum_i32x4; // four lane-wise partial dot products
};
|
|
137
|
+
|
|
138
|
+
/**
 * @brief Clears an i8x16 SDOT accumulator before streaming starts.
 * @param state Accumulator to reset to all-zero lanes.
 */
NK_INTERNAL void nk_dot_i8x16_init_neonsdot(nk_dot_i8x16_state_neonsdot_t *state) {
    int32x4_t const zeros_i32x4 = vdupq_n_s32(0);
    state->sum_i32x4 = zeros_i32x4;
}
|
|
139
|
+
|
|
140
|
+
/**
 * @brief Folds one 16-byte chunk of signed 8-bit values into the SDOT accumulator.
 *
 * @param state Accumulator updated in place.
 * @param a 16 packed int8 values of the first operand.
 * @param b 16 packed int8 values of the second operand.
 * @param depth_offset Ignored by this backend.
 * @param active_dimensions Ignored: all 16 lanes are always accumulated, so callers presumably
 *        zero-pad partial chunks. TODO: confirm against callers.
 */
NK_INTERNAL void nk_dot_i8x16_update_neonsdot(nk_dot_i8x16_state_neonsdot_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    int8x16_t const lhs_i8x16 = vreinterpretq_s8_u32(a.u32x4);
    int8x16_t const rhs_i8x16 = vreinterpretq_s8_u32(b.u32x4);
    state->sum_i32x4 = vdotq_s32(state->sum_i32x4, lhs_i8x16, rhs_i8x16);
}
|
|
148
|
+
|
|
149
|
+
/**
 * @brief Reduces four i8x16 accumulators into four scalar dot products.
 *
 * @param state_a Accumulator whose sum goes to `result->i32s[0]`.
 * @param state_b Accumulator whose sum goes to `result->i32s[1]`.
 * @param state_c Accumulator whose sum goes to `result->i32s[2]`.
 * @param state_d Accumulator whose sum goes to `result->i32s[3]`.
 * @param total_dimensions Ignored by this backend (no correction term is needed).
 * @param result Output vector receiving the four horizontal sums.
 */
NK_INTERNAL void nk_dot_i8x16_finalize_neonsdot( //
    nk_dot_i8x16_state_neonsdot_t const *state_a, nk_dot_i8x16_state_neonsdot_t const *state_b, //
    nk_dot_i8x16_state_neonsdot_t const *state_c, nk_dot_i8x16_state_neonsdot_t const *state_d, //
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);
    nk_dot_i8x16_state_neonsdot_t const *const states[4] = {state_a, state_b, state_c, state_d};
    for (int lane = 0; lane < 4; ++lane) {
        result->i32s[lane] = vaddvq_s32(states[lane]->sum_i32x4);
    }
}
|
|
159
|
+
|
|
160
|
+
/**
 * @brief Running state for 128-bit dot accumulation over u8 scalars on NEON.
 *
 * Holds four uint32 partial sums produced by UDOT. They are reduced horizontally only in
 * `nk_dot_u8x16_finalize_neonsdot`.
 */
typedef struct nk_dot_u8x16_state_neonsdot_t nk_dot_u8x16_state_neonsdot_t;
struct nk_dot_u8x16_state_neonsdot_t {
    uint32x4_t sum_u32x4; // four lane-wise partial dot products
};
|
|
166
|
+
|
|
167
|
+
/**
 * @brief Clears a u8x16 UDOT accumulator before streaming starts.
 * @param state Accumulator to reset to all-zero lanes.
 */
NK_INTERNAL void nk_dot_u8x16_init_neonsdot(nk_dot_u8x16_state_neonsdot_t *state) {
    uint32x4_t const zeros_u32x4 = vdupq_n_u32(0);
    state->sum_u32x4 = zeros_u32x4;
}
|
|
168
|
+
|
|
169
|
+
/**
 * @brief Folds one 16-byte chunk of unsigned 8-bit values into the UDOT accumulator.
 *
 * @param state Accumulator updated in place.
 * @param a 16 packed uint8 values of the first operand.
 * @param b 16 packed uint8 values of the second operand.
 * @param depth_offset Ignored by this backend.
 * @param active_dimensions Ignored: all 16 lanes are always accumulated, so callers presumably
 *        zero-pad partial chunks. TODO: confirm against callers.
 */
NK_INTERNAL void nk_dot_u8x16_update_neonsdot(nk_dot_u8x16_state_neonsdot_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    uint8x16_t const lhs_u8x16 = vreinterpretq_u8_u32(a.u32x4);
    uint8x16_t const rhs_u8x16 = vreinterpretq_u8_u32(b.u32x4);
    state->sum_u32x4 = vdotq_u32(state->sum_u32x4, lhs_u8x16, rhs_u8x16);
}
|
|
177
|
+
|
|
178
|
+
/**
 * @brief Reduces four u8x16 accumulators into four scalar dot products.
 *
 * @param state_a Accumulator whose sum goes to `result->u32s[0]`.
 * @param state_b Accumulator whose sum goes to `result->u32s[1]`.
 * @param state_c Accumulator whose sum goes to `result->u32s[2]`.
 * @param state_d Accumulator whose sum goes to `result->u32s[3]`.
 * @param total_dimensions Ignored by this backend.
 * @param result Output vector receiving the four horizontal sums.
 */
NK_INTERNAL void nk_dot_u8x16_finalize_neonsdot( //
    nk_dot_u8x16_state_neonsdot_t const *state_a, nk_dot_u8x16_state_neonsdot_t const *state_b, //
    nk_dot_u8x16_state_neonsdot_t const *state_c, nk_dot_u8x16_state_neonsdot_t const *state_d, //
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);
    nk_dot_u8x16_state_neonsdot_t const *const states[4] = {state_a, state_b, state_c, state_d};
    for (int lane = 0; lane < 4; ++lane) {
        result->u32s[lane] = vaddvq_u32(states[lane]->sum_u32x4);
    }
}
|
|
188
|
+
|
|
189
|
+
/**
 * @brief Dot product of two packed signed 4-bit vectors using SDOT.
 *
 * Each byte carries two signed nibbles. Both nibbles are sign-extended from [-8,7] into int8
 * lanes and multiplied with SDOT, which handles signed x signed directly. No bias correction is
 * needed, unlike the x86 DPBUSD path.
 *
 * @param a First packed operand.
 * @param b Second packed operand.
 * @param n Number of 4-bit values (dimensions), not bytes. An odd `n` is rounded up to whole
 *          bytes, so the spare nibble of the last byte also enters the sum.
 * @param result Output: the 32-bit signed dot product.
 */
NK_PUBLIC void nk_dot_i4_neonsdot(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result) {
    nk_size_t const total_bytes = nk_size_round_up_to_multiple_(n, 2) / 2;
    int32x4_t acc_i32x4 = vdupq_n_s32(0);

    for (nk_size_t offset = 0; offset < total_bytes; offset += 16) {
        nk_size_t const remaining = total_bytes - offset;
        int8x16_t a_packed_i8x16, b_packed_i8x16;
        if (remaining >= 16) {
            a_packed_i8x16 = vreinterpretq_s8_u8(vld1q_u8((nk_u8_t const *)(a + offset)));
            b_packed_i8x16 = vreinterpretq_s8_u8(vld1q_u8((nk_u8_t const *)(b + offset)));
        }
        else {
            // Zero-padded copy of the final partial chunk; zero bytes decode to zero nibbles.
            nk_b128_vec_t a_tail = {0}, b_tail = {0};
            nk_u8_t const *a_bytes = (nk_u8_t const *)(a + offset);
            nk_u8_t const *b_bytes = (nk_u8_t const *)(b + offset);
            for (nk_size_t i = 0; i < remaining; ++i) {
                a_tail.u8s[i] = a_bytes[i];
                b_tail.u8s[i] = b_bytes[i];
            }
            a_packed_i8x16 = vreinterpretq_s8_u8(a_tail.u8x16);
            b_packed_i8x16 = vreinterpretq_s8_u8(b_tail.u8x16);
        }

        // Low nibble: move it into the upper half of the byte, then arithmetic-shift back to sign-extend.
        int8x16_t const a_low_i8x16 = vshrq_n_s8(vshlq_n_s8(a_packed_i8x16, 4), 4);
        int8x16_t const b_low_i8x16 = vshrq_n_s8(vshlq_n_s8(b_packed_i8x16, 4), 4);
        // High nibble: a single arithmetic right shift already yields its sign-extended value.
        int8x16_t const a_high_i8x16 = vshrq_n_s8(a_packed_i8x16, 4);
        int8x16_t const b_high_i8x16 = vshrq_n_s8(b_packed_i8x16, 4);

        acc_i32x4 = vdotq_s32(acc_i32x4, a_low_i8x16, b_low_i8x16);
        acc_i32x4 = vdotq_s32(acc_i32x4, a_high_i8x16, b_high_i8x16);
    }

    *result = vaddvq_s32(acc_i32x4);
}
|
|
244
|
+
|
|
245
|
+
/**
 * @brief Dot product of two packed unsigned 4-bit vectors using UDOT.
 *
 * Each byte carries two unsigned nibbles in [0,15]. They are split into separate byte lanes
 * (mask for the low nibble, logical shift for the high nibble) and multiplied with UDOT directly.
 *
 * @param a First packed operand.
 * @param b Second packed operand.
 * @param n Number of 4-bit values (dimensions), not bytes. An odd `n` is rounded up to whole
 *          bytes, so the spare nibble of the last byte also enters the sum.
 * @param result Output: the 32-bit unsigned dot product.
 */
NK_PUBLIC void nk_dot_u4_neonsdot(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
    nk_size_t const total_bytes = nk_size_round_up_to_multiple_(n, 2) / 2;
    uint8x16_t const low_nibble_mask_u8x16 = vdupq_n_u8(0x0F);
    uint32x4_t acc_u32x4 = vdupq_n_u32(0);

    for (nk_size_t offset = 0; offset < total_bytes; offset += 16) {
        nk_size_t const remaining = total_bytes - offset;
        uint8x16_t a_packed_u8x16, b_packed_u8x16;
        if (remaining >= 16) {
            a_packed_u8x16 = vld1q_u8((nk_u8_t const *)(a + offset));
            b_packed_u8x16 = vld1q_u8((nk_u8_t const *)(b + offset));
        }
        else {
            // Zero-padded copy of the final partial chunk; zero bytes contribute nothing.
            nk_b128_vec_t a_tail = {0}, b_tail = {0};
            nk_u8_t const *a_bytes = (nk_u8_t const *)(a + offset);
            nk_u8_t const *b_bytes = (nk_u8_t const *)(b + offset);
            for (nk_size_t i = 0; i < remaining; ++i) {
                a_tail.u8s[i] = a_bytes[i];
                b_tail.u8s[i] = b_bytes[i];
            }
            a_packed_u8x16 = a_tail.u8x16;
            b_packed_u8x16 = b_tail.u8x16;
        }

        acc_u32x4 = vdotq_u32(acc_u32x4, vandq_u8(a_packed_u8x16, low_nibble_mask_u8x16),
                              vandq_u8(b_packed_u8x16, low_nibble_mask_u8x16));
        acc_u32x4 = vdotq_u32(acc_u32x4, vshrq_n_u8(a_packed_u8x16, 4), vshrq_n_u8(b_packed_u8x16, 4));
    }

    *result = vaddvq_u32(acc_u32x4);
}
|
|
290
|
+
|
|
291
|
+
/**
 * @brief Running state for dot accumulation over 32 packed signed 4-bit values (one 128-bit chunk).
 *
 * Holds four int32 partial sums produced by SDOT on sign-extended nibbles.
 */
typedef struct nk_dot_i4x32_state_neonsdot_t nk_dot_i4x32_state_neonsdot_t;
struct nk_dot_i4x32_state_neonsdot_t {
    int32x4_t product_sum_i32x4; // four lane-wise partial dot products
};
|
|
294
|
+
|
|
295
|
+
/**
 * @brief Clears an i4x32 SDOT accumulator before streaming starts.
 * @param state Accumulator to reset to all-zero lanes.
 */
NK_INTERNAL void nk_dot_i4x32_init_neonsdot(nk_dot_i4x32_state_neonsdot_t *state) {
    int32x4_t const zeros_i32x4 = vdupq_n_s32(0);
    state->product_sum_i32x4 = zeros_i32x4;
}
|
|
298
|
+
|
|
299
|
+
/**
 * @brief Folds 32 packed signed nibbles (16 bytes) into the SDOT accumulator.
 *
 * Each nibble is sign-extended to int8, so SDOT yields the signed product directly and no
 * correction term is needed at finalize time.
 *
 * @param state Accumulator updated in place.
 * @param a_i4x32 16 bytes holding 32 signed nibbles of the first operand.
 * @param b_i4x32 16 bytes holding 32 signed nibbles of the second operand.
 * @param depth_offset Ignored by this backend.
 * @param active_dimensions Ignored: all 32 nibbles are always accumulated, so callers presumably
 *        zero-pad partial chunks. TODO: confirm against callers.
 */
NK_INTERNAL void nk_dot_i4x32_update_neonsdot(nk_dot_i4x32_state_neonsdot_t *state, nk_b128_vec_t a_i4x32,
                                              nk_b128_vec_t b_i4x32, nk_size_t depth_offset,
                                              nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    int8x16_t const a_packed_i8x16 = vreinterpretq_s8_u8(a_i4x32.u8x16);
    int8x16_t const b_packed_i8x16 = vreinterpretq_s8_u8(b_i4x32.u8x16);

    // Low nibble: move it into the upper half of the byte, then arithmetic-shift back to sign-extend.
    int8x16_t const a_low_i8x16 = vshrq_n_s8(vshlq_n_s8(a_packed_i8x16, 4), 4);
    int8x16_t const b_low_i8x16 = vshrq_n_s8(vshlq_n_s8(b_packed_i8x16, 4), 4);
    // High nibble: a single arithmetic right shift already yields its sign-extended value.
    int8x16_t const a_high_i8x16 = vshrq_n_s8(a_packed_i8x16, 4);
    int8x16_t const b_high_i8x16 = vshrq_n_s8(b_packed_i8x16, 4);

    int32x4_t const with_low_i32x4 = vdotq_s32(state->product_sum_i32x4, a_low_i8x16, b_low_i8x16);
    state->product_sum_i32x4 = vdotq_s32(with_low_i32x4, a_high_i8x16, b_high_i8x16);
}
|
|
324
|
+
|
|
325
|
+
/**
 * @brief Reduces four i4x32 accumulators into four scalar dot products.
 *
 * Because nibbles were sign-extended before SDOT, a plain horizontal sum is the final answer.
 *
 * @param state_a Accumulator whose sum goes to `result->i32s[0]`.
 * @param state_b Accumulator whose sum goes to `result->i32s[1]`.
 * @param state_c Accumulator whose sum goes to `result->i32s[2]`.
 * @param state_d Accumulator whose sum goes to `result->i32s[3]`.
 * @param total_dimensions Ignored by this backend.
 * @param result Output vector receiving the four horizontal sums.
 */
NK_INTERNAL void nk_dot_i4x32_finalize_neonsdot( //
    nk_dot_i4x32_state_neonsdot_t const *state_a, nk_dot_i4x32_state_neonsdot_t const *state_b, //
    nk_dot_i4x32_state_neonsdot_t const *state_c, nk_dot_i4x32_state_neonsdot_t const *state_d, //
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);
    nk_dot_i4x32_state_neonsdot_t const *const states[4] = {state_a, state_b, state_c, state_d};
    for (int lane = 0; lane < 4; ++lane) {
        result->i32s[lane] = vaddvq_s32(states[lane]->product_sum_i32x4);
    }
}
|
|
336
|
+
|
|
337
|
+
/**
 *  Running accumulator for the unsigned 4-bit (u4) dot product using UDOT.
 *  It is zeroed by `nk_dot_u4x32_init_neonsdot`, accumulated by `nk_dot_u4x32_update_neonsdot`, and reduced
 *  horizontally by `nk_dot_u4x32_finalize_neonsdot`.
 */
typedef struct nk_dot_u4x32_state_neonsdot_t {
    uint32x4_t product_sum_u32x4; // four u32 partial sums of nibble products, one per UDOT lane
} nk_dot_u4x32_state_neonsdot_t;
|
|
340
|
+
|
|
341
|
+
/** Clears the u4 accumulator so that all four u32 partial-sum lanes start from zero. */
NK_INTERNAL void nk_dot_u4x32_init_neonsdot(nk_dot_u4x32_state_neonsdot_t *state) {
    uint32x4_t const zeros_u32x4 = vdupq_n_u32(0u);
    state->product_sum_u32x4 = zeros_u32x4;
}
|
|
344
|
+
|
|
345
|
+
/**
 *  Accumulates the unsigned dot product of 32 packed 4-bit integers (two nibbles per byte) into `state`.
 *  Nibbles lie in [0, 15], so they can be fed to UDOT as-is without any sign handling or bias correction.
 *  `depth_offset` and `active_dimensions` are ignored.
 */
NK_INTERNAL void nk_dot_u4x32_update_neonsdot(nk_dot_u4x32_state_neonsdot_t *state, nk_b128_vec_t a_u4x32,
                                              nk_b128_vec_t b_u4x32, nk_size_t depth_offset,
                                              nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    uint8x16_t const low_nibbles_mask_u8x16 = vdupq_n_u8(0x0F);
    uint8x16_t const a_packed_u8x16 = a_u4x32.u8x16;
    uint8x16_t const b_packed_u8x16 = b_u4x32.u8x16;
    uint32x4_t accumulator_u32x4 = state->product_sum_u32x4;

    // Low nibbles: masked in place.
    accumulator_u32x4 = vdotq_u32(accumulator_u32x4, vandq_u8(a_packed_u8x16, low_nibbles_mask_u8x16),
                                  vandq_u8(b_packed_u8x16, low_nibbles_mask_u8x16));
    // High nibbles: a logical right shift leaves them in [0, 15].
    accumulator_u32x4 = vdotq_u32(accumulator_u32x4, vshrq_n_u8(a_packed_u8x16, 4), vshrq_n_u8(b_packed_u8x16, 4));

    state->product_sum_u32x4 = accumulator_u32x4;
}
|
|
364
|
+
|
|
365
|
+
/**
 *  Reduces four u4 dot-product accumulators to scalars. The states a, b, c, d are written to
 *  `result->u32s[0]` .. `result->u32s[3]` respectively.
 *  `total_dimensions` is unused: UDOT on raw nibbles needs no correction term.
 */
NK_INTERNAL void nk_dot_u4x32_finalize_neonsdot( //
    nk_dot_u4x32_state_neonsdot_t const *state_a, nk_dot_u4x32_state_neonsdot_t const *state_b, //
    nk_dot_u4x32_state_neonsdot_t const *state_c, nk_dot_u4x32_state_neonsdot_t const *state_d, //
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);
    nk_dot_u4x32_state_neonsdot_t const *const states[4] = {state_a, state_b, state_c, state_d};
    for (nk_size_t slot = 0; slot < 4; ++slot) {
        result->u32s[slot] = vaddvq_u32(states[slot]->product_sum_u32x4);
    }
}
|
|
376
|
+
|
|
377
|
+
/**
 *  Dot product of two e2m3 vectors (sign in bit 5, 5-bit magnitude code in bits 0-4) using SDOT.
 *
 *  Every e2m3 value × 16 is an exact integer in [-120, +120], so it fits a signed i8. Each product of scaled
 *  values equals (a·b) × 256. The scaled sum is accumulated exactly in 64-bit integers and divided by 256 at
 *  the end; the only rounding is the final integer → f32 conversion.
 *
 *  Why 64-bit: with a single i32 accumulator and an i32 horizontal add, the scaled sum (up to 14400 per
 *  element) wraps past 2^31 after roughly 149k maximum-magnitude (±7.5) elements, even though the true result
 *  is a perfectly representable f32.
 */
NK_PUBLIC void nk_dot_e2m3_neonsdot(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
                                    nk_f32_t *result) {
    // Magnitude code → |value| × 16. A 32-entry LUT, looked up with a single vqtbl2q_u8 (indices 0-31).
    static nk_u8_t const lut_data[32] = {0,  2,  4,  6,  8,  10, 12, 14, 16, 18, 20, 22, 24,  26,  28,  30,
                                         32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120};
    uint8x16x2_t const lut_magnitude_u8x16x2 = vld1q_u8_x2(lut_data);
    uint8x16_t const magnitude_mask_u8x16 = vdupq_n_u8(0x1F);
    uint8x16_t const sign_mask_u8x16 = vdupq_n_u8(0x20);
    int32x4_t const zeros_i32x4 = vdupq_n_s32(0);
    int64x2_t sum_i64x2 = vdupq_n_s64(0);
    uint8x16_t a_e2m3_u8x16, b_e2m3_u8x16;

nk_dot_e2m3_neonsdot_cycle:
    if (count_scalars < 16) {
        nk_b128_vec_t a_vec, b_vec;
        nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
        nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
        a_e2m3_u8x16 = a_vec.u8x16;
        b_e2m3_u8x16 = b_vec.u8x16;
        count_scalars = 0;
    }
    else {
        a_e2m3_u8x16 = vld1q_u8((nk_u8_t const *)a_scalars);
        b_e2m3_u8x16 = vld1q_u8((nk_u8_t const *)b_scalars);
        a_scalars += 16, b_scalars += 16, count_scalars -= 16;
    }

    // Extract 5-bit magnitude indices and look up the scaled magnitudes.
    uint8x16_t a_magnitude_u8x16 = vandq_u8(a_e2m3_u8x16, magnitude_mask_u8x16);
    uint8x16_t b_magnitude_u8x16 = vandq_u8(b_e2m3_u8x16, magnitude_mask_u8x16);
    uint8x16_t a_unsigned_u8x16 = vqtbl2q_u8(lut_magnitude_u8x16x2, a_magnitude_u8x16);
    uint8x16_t b_unsigned_u8x16 = vqtbl2q_u8(lut_magnitude_u8x16x2, b_magnitude_u8x16);

    // Combined sign: (a ^ b) & 0x20 is set exactly when the product is negative. Folding the sign into b alone
    // saves negating both operands.
    uint8x16_t sign_combined_u8x16 = vandq_u8(veorq_u8(a_e2m3_u8x16, b_e2m3_u8x16), sign_mask_u8x16);
    uint8x16_t negate_mask_u8x16 = vceqq_u8(sign_combined_u8x16, sign_mask_u8x16);
    int8x16_t b_signed_i8x16 = vbslq_s8(negate_mask_u8x16, vnegq_s8(vreinterpretq_s8_u8(b_unsigned_u8x16)),
                                        vreinterpretq_s8_u8(b_unsigned_u8x16));

    // SDOT into a fresh zero: each i32 lane holds at most 4 × 120² = 57600 for this block. SADALP then
    // pairwise-widens those lanes into the 64-bit running sum, so nothing wraps across blocks.
    int32x4_t block_i32x4 = vdotq_s32(zeros_i32x4, vreinterpretq_s8_u8(a_unsigned_u8x16), b_signed_i8x16);
    sum_i64x2 = vpadalq_s32(sum_i64x2, block_i32x4);

    if (count_scalars) goto nk_dot_e2m3_neonsdot_cycle;
    // Dividing by 256 (a power of two) is exact, so the only rounding is the integer → f32 conversion.
    *result = (nk_f32_t)vaddvq_s64(sum_i64x2) / 256.0f;
}
|
|
427
|
+
|
|
428
|
+
/**
 *  Dot product of two e3m2 vectors (sign in bit 5, 5-bit magnitude code in bits 0-4).
 *  Uses an i16 magnitude LUT and widening i16 × i16 → i32 multiplies.
 *
 *  Every e3m2 value × 16 is an exact integer, but magnitudes reach 448, so i16 is required.
 *  - Low bytes of the scaled magnitude come from a 32-entry LUT via vqtbl2q_u8.
 *  - The high byte is 1 only for codes 28-31 (256-448) and is derived with a `>= 28` comparison.
 *
 *  Each product of scaled values equals (a·b) × 256. The scaled sum is accumulated exactly in 64-bit integers
 *  and divided by 256 at the end.
 *
 *  Why 64-bit: a single scaled product reaches 448² = 200704. A 32-bit horizontal reduction therefore wraps
 *  past 2^31 after only about 10.7k maximum-magnitude (±28.0) elements, while the true result is still a small
 *  f32.
 */
NK_PUBLIC void nk_dot_e3m2_neonsdot(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
                                    nk_f32_t *result) {
    static nk_u8_t const lut_data[32] = {0,  1,  2,  3,  4,  5,  6,   7,   8,   10,  12,  14, 16, 20,  24,  28,
                                         32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 0,  64, 128, 192};
    uint8x16x2_t const lut_low_u8x16x2 = vld1q_u8_x2(lut_data);
    uint8x16_t const high_threshold_u8x16 = vdupq_n_u8(28);
    uint8x16_t const one_u8x16 = vdupq_n_u8(1);
    uint8x16_t const magnitude_mask_u8x16 = vdupq_n_u8(0x1F);
    uint8x16_t const sign_mask_u8x16 = vdupq_n_u8(0x20);
    int64x2_t sum_i64x2 = vdupq_n_s64(0);
    uint8x16_t a_e3m2_u8x16, b_e3m2_u8x16;

nk_dot_e3m2_neonsdot_cycle:
    if (count_scalars < 16) {
        nk_b128_vec_t a_vec, b_vec;
        nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
        nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
        a_e3m2_u8x16 = a_vec.u8x16;
        b_e3m2_u8x16 = b_vec.u8x16;
        count_scalars = 0;
    }
    else {
        a_e3m2_u8x16 = vld1q_u8((nk_u8_t const *)a_scalars);
        b_e3m2_u8x16 = vld1q_u8((nk_u8_t const *)b_scalars);
        a_scalars += 16, b_scalars += 16, count_scalars -= 16;
    }

    // Extract 5-bit magnitude indices.
    uint8x16_t a_mag_u8x16 = vandq_u8(a_e3m2_u8x16, magnitude_mask_u8x16);
    uint8x16_t b_mag_u8x16 = vandq_u8(b_e3m2_u8x16, magnitude_mask_u8x16);

    // Low bytes come from the LUT; the high byte is 1 iff index >= 28.
    uint8x16_t a_lo_u8x16 = vqtbl2q_u8(lut_low_u8x16x2, a_mag_u8x16);
    uint8x16_t b_lo_u8x16 = vqtbl2q_u8(lut_low_u8x16x2, b_mag_u8x16);
    uint8x16_t a_hi_u8x16 = vandq_u8(vcgeq_u8(a_mag_u8x16, high_threshold_u8x16), one_u8x16);
    uint8x16_t b_hi_u8x16 = vandq_u8(vcgeq_u8(b_mag_u8x16, high_threshold_u8x16), one_u8x16);

    // Combine low and high bytes into i16 by byte interleave (little-endian: low byte first).
    int16x8_t a_unsigned_low_i16x8 = vreinterpretq_s16_u8(vzip1q_u8(a_lo_u8x16, a_hi_u8x16));
    int16x8_t a_unsigned_high_i16x8 = vreinterpretq_s16_u8(vzip2q_u8(a_lo_u8x16, a_hi_u8x16));
    int16x8_t b_unsigned_low_i16x8 = vreinterpretq_s16_u8(vzip1q_u8(b_lo_u8x16, b_hi_u8x16));
    int16x8_t b_unsigned_high_i16x8 = vreinterpretq_s16_u8(vzip2q_u8(b_lo_u8x16, b_hi_u8x16));

    // Combined sign: XOR the sign bits and negate only b, which is cheaper than negating both operands.
    // The byte mask is duplicated into 16-bit lanes to select whole i16 elements.
    uint8x16_t sign_combined_u8x16 = vandq_u8(veorq_u8(a_e3m2_u8x16, b_e3m2_u8x16), sign_mask_u8x16);
    uint8x16_t negate_mask_u8x16 = vceqq_u8(sign_combined_u8x16, sign_mask_u8x16);
    uint16x8_t negate_low_u16x8 = vreinterpretq_u16_u8(vzip1q_u8(negate_mask_u8x16, negate_mask_u8x16));
    uint16x8_t negate_high_u16x8 = vreinterpretq_u16_u8(vzip2q_u8(negate_mask_u8x16, negate_mask_u8x16));
    int16x8_t b_signed_low_i16x8 = vbslq_s16(negate_low_u16x8, vnegq_s16(b_unsigned_low_i16x8), b_unsigned_low_i16x8);
    int16x8_t b_signed_high_i16x8 = vbslq_s16(negate_high_u16x8, vnegq_s16(b_unsigned_high_i16x8),
                                              b_unsigned_high_i16x8);

    // Per-block widening products: each i32 lane of low/high holds 2 products (|sum| <= 2 × 448² = 401408),
    // and their sum is at most 802816, so the block total is exact in i32. SADALP then pairwise-widens it into
    // the 64-bit running sum, so nothing wraps across blocks.
    int32x4_t low_i32x4 = vmull_s16(vget_low_s16(a_unsigned_low_i16x8), vget_low_s16(b_signed_low_i16x8));
    low_i32x4 = vmlal_high_s16(low_i32x4, a_unsigned_low_i16x8, b_signed_low_i16x8);
    int32x4_t high_i32x4 = vmull_s16(vget_low_s16(a_unsigned_high_i16x8), vget_low_s16(b_signed_high_i16x8));
    high_i32x4 = vmlal_high_s16(high_i32x4, a_unsigned_high_i16x8, b_signed_high_i16x8);
    sum_i64x2 = vpadalq_s32(sum_i64x2, vaddq_s32(low_i32x4, high_i32x4));

    if (count_scalars) goto nk_dot_e3m2_neonsdot_cycle;
    // Dividing by 256 (a power of two) is exact, so the only rounding is the integer → f32 conversion.
    *result = (nk_f32_t)vaddvq_s64(sum_i64x2) / 256.0f;
}
|
|
495
|
+
|
|
496
|
+
#if defined(__clang__)
|
|
497
|
+
#pragma clang attribute pop
|
|
498
|
+
#elif defined(__GNUC__)
|
|
499
|
+
#pragma GCC pop_options
|
|
500
|
+
#endif
|
|
501
|
+
|
|
502
|
+
#if defined(__cplusplus)
|
|
503
|
+
} // extern "C"
|
|
504
|
+
#endif
|
|
505
|
+
|
|
506
|
+
#endif // NK_TARGET_NEONSDOT
|
|
507
|
+
#endif // NK_TARGET_ARM_
|
|
508
|
+
#endif // NK_DOT_NEONSDOT_H
|