numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,838 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SWAR-accelerated Dot Products for SIMD-free CPUs.
|
|
3
|
+
* @file include/numkong/dot/serial.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_serial_instructions Serial Fallback Implementation
|
|
10
|
+
*
|
|
11
|
+
* The serial backend provides portable scalar implementations for all numeric types without requiring
|
|
12
|
+
* any SIMD extensions. While significantly slower than vectorized implementations, these serve as:
|
|
13
|
+
*
|
|
14
|
+
* - Reference implementations for correctness validation
|
|
15
|
+
* - Fallbacks for platforms without SIMD support (WASM, older CPUs)
|
|
16
|
+
* - Baseline for benchmarking vectorized speedups
|
|
17
|
+
*
|
|
18
|
+
* For f64 dot products, compensated (Kahan-style) summation is used to minimize floating-point
|
|
19
|
+
* accumulation errors. f32 inputs are accumulated in f64, while smaller types (f16, bf16, FP8, FP6) are upcast to f32.
|
|
20
|
+
*
|
|
21
|
+
* @section dot_serial_stateful Stateful Streaming Logic
|
|
22
|
+
*
|
|
23
|
+
* To build memory-optimal tiled algorithms, this file defines the following structures and force-inlined
|
|
24
|
+
* `NK_INTERNAL` functions:
|
|
25
|
+
*
|
|
26
|
+
* - nk_dot_f64x2 state with compensated summation for numerical stability,
|
|
27
|
+
* - nk_dot_f32x4 state for f32 inputs with simple (uncompensated) f64 accumulation,
|
|
28
|
+
* - nk_dot_f16x8 state for f16 inputs via f32 upcasting,
|
|
29
|
+
* - nk_dot_bf16x8 state for bf16 inputs via f32 upcasting,
|
|
30
|
+
* - nk_dot_i8x16 for 8-bit signed integer inputs,
|
|
31
|
+
* - nk_dot_u8x16 for 8-bit unsigned integer inputs,
|
|
32
|
+
* - nk_dot_e4m3x16, nk_dot_e5m2x16, nk_dot_e2m3x16, nk_dot_e3m2x16 for FP8/FP6 inputs,
|
|
33
|
+
* - nk_dot_i4x16, nk_dot_u4x16 for 4-bit integer inputs.
|
|
34
|
+
*
|
|
35
|
+
* @code{c}
|
|
36
|
+
* nk_dot_f64x2_state_serial_t state_first, state_second, state_third, state_fourth;
|
|
37
|
+
* nk_b128_vec_t query_f64x2, target_first_f64x2, target_second_f64x2, target_third_f64x2, target_fourth_f64x2;
|
|
38
|
+
* nk_dot_f64x2_init_serial(&state_first);
|
|
39
|
+
* nk_dot_f64x2_init_serial(&state_second);
|
|
40
|
+
* nk_dot_f64x2_init_serial(&state_third);
|
|
41
|
+
* nk_dot_f64x2_init_serial(&state_fourth);
|
|
42
|
+
* for (nk_size_t idx = 0; idx + 2 <= depth; idx += 2) {
|
|
43
|
+
* query_f64x2.f64s[0] = query_ptr[idx], query_f64x2.f64s[1] = query_ptr[idx + 1];
|
|
44
|
+
* target_first_f64x2.f64s[0] = target_first_ptr[idx], target_first_f64x2.f64s[1] = target_first_ptr[idx + 1];
|
|
45
|
+
* target_second_f64x2.f64s[0] = target_second_ptr[idx], target_second_f64x2.f64s[1] = target_second_ptr[idx + 1];
|
|
46
|
+
* target_third_f64x2.f64s[0] = target_third_ptr[idx], target_third_f64x2.f64s[1] = target_third_ptr[idx + 1];
|
|
47
|
+
* target_fourth_f64x2.f64s[0] = target_fourth_ptr[idx], target_fourth_f64x2.f64s[1] = target_fourth_ptr[idx + 1];
|
|
48
|
+
* nk_dot_f64x2_update_serial(&state_first, query_f64x2, target_first_f64x2, idx, 2);
|
|
49
|
+
* nk_dot_f64x2_update_serial(&state_second, query_f64x2, target_second_f64x2, idx, 2);
|
|
50
|
+
* nk_dot_f64x2_update_serial(&state_third, query_f64x2, target_third_f64x2, idx, 2);
|
|
51
|
+
* nk_dot_f64x2_update_serial(&state_fourth, query_f64x2, target_fourth_f64x2, idx, 2);
|
|
52
|
+
* }
|
|
53
|
+
* nk_b256_vec_t results_f64x4;
|
|
54
|
+
* nk_dot_f64x2_finalize_serial(&state_first, &state_second, &state_third, &state_fourth, depth, &results_f64x4);
|
|
55
|
+
* @endcode
|
|
56
|
+
*
|
|
57
|
+
* Integer types follow a similar pattern with appropriate type changes:
|
|
58
|
+
*
|
|
59
|
+
* @code{c}
|
|
60
|
+
* nk_dot_i8x16_state_serial_t state_first, state_second, state_third, state_fourth;
|
|
61
|
+
* nk_b128_vec_t query_i8x16, target_first_i8x16, target_second_i8x16, target_third_i8x16, target_fourth_i8x16;
|
|
62
|
+
* nk_dot_i8x16_init_serial(&state_first);
|
|
63
|
+
* nk_dot_i8x16_init_serial(&state_second);
|
|
64
|
+
* nk_dot_i8x16_init_serial(&state_third);
|
|
65
|
+
* nk_dot_i8x16_init_serial(&state_fourth);
|
|
66
|
+
* for (nk_size_t idx = 0; idx + 16 <= depth; idx += 16) {
|
|
67
|
+
* memcpy(query_i8x16.i8s, query_ptr + idx, 16);
|
|
68
|
+
* memcpy(target_first_i8x16.i8s, target_first_ptr + idx, 16);
|
|
69
|
+
* memcpy(target_second_i8x16.i8s, target_second_ptr + idx, 16);
|
|
70
|
+
* memcpy(target_third_i8x16.i8s, target_third_ptr + idx, 16);
|
|
71
|
+
* memcpy(target_fourth_i8x16.i8s, target_fourth_ptr + idx, 16);
|
|
72
|
+
* nk_dot_i8x16_update_serial(&state_first, query_i8x16, target_first_i8x16, idx, 16);
|
|
73
|
+
* nk_dot_i8x16_update_serial(&state_second, query_i8x16, target_second_i8x16, idx, 16);
|
|
74
|
+
* nk_dot_i8x16_update_serial(&state_third, query_i8x16, target_third_i8x16, idx, 16);
|
|
75
|
+
* nk_dot_i8x16_update_serial(&state_fourth, query_i8x16, target_fourth_i8x16, idx, 16);
|
|
76
|
+
* }
|
|
77
|
+
* nk_b128_vec_t results_i32x4;
|
|
78
|
+
* nk_dot_i8x16_finalize_serial(&state_first, &state_second, &state_third, &state_fourth, depth, &results_i32x4);
|
|
79
|
+
* @endcode
|
|
80
|
+
*/
|
|
81
|
+
#ifndef NK_DOT_SERIAL_H
|
|
82
|
+
#define NK_DOT_SERIAL_H
|
|
83
|
+
|
|
84
|
+
#include "numkong/types.h"
|
|
85
|
+
#include "numkong/reduce/serial.h" // `nk_f64_abs_`
|
|
86
|
+
|
|
87
|
+
#if defined(__cplusplus)
|
|
88
|
+
extern "C" {
|
|
89
|
+
#endif
|
|
90
|
+
|
|
91
|
+
/**
 *  @brief  Stamps out `nk_dot_<input_type>_serial`: a plain, uncompensated real dot product.
 *
 *  Every element of `a` and `b` is widened through `load_and_convert` into `nk_<accumulator_type>_t`.
 *  The two operands are multiplied and the product is folded into a single running total. That total
 *  is cast to `nk_<output_type>_t` and stored through `result`.
 *
 *  @param input_type        Suffix of the element type, e.g. `f16` selects `nk_f16_t`.
 *  @param accumulator_type  Suffix of the type used for the converted operands and the running total.
 *  @param output_type       Suffix of the type written through `result`.
 *  @param load_and_convert  Invoked as `load_and_convert(source_pointer, destination_pointer)`.
 */
#define nk_define_dot_(input_type, accumulator_type, output_type, load_and_convert)                           \
    NK_PUBLIC void nk_dot_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
                                                 nk_size_t n, nk_##output_type##_t *result) {             \
        nk_##accumulator_type##_t total = 0;                                                                \
        for (nk_size_t idx = 0; idx < n; ++idx) {                                                           \
            nk_##accumulator_type##_t lhs, rhs;                                                             \
            load_and_convert(a + idx, &lhs);                                                                \
            load_and_convert(b + idx, &rhs);                                                                \
            total += lhs * rhs;                                                                             \
        }                                                                                                   \
        *result = (nk_##output_type##_t)total;                                                              \
    }
|
|
105
|
+
|
|
106
|
+
/**
 *  @brief  Stamps out `nk_dot_<input_type>_serial` for complex inputs: the sum over k of a[k] * b[k].
 *
 *  The real and imaginary parts of each pair are widened via `load_and_convert`. The real accumulator
 *  gathers `ar*br - ai*bi` and the imaginary accumulator gathers `ar*bi + ai*br`; no operand is
 *  conjugated. Both accumulators are assigned to `result->real` / `result->imag` without an explicit cast.
 *
 *  @param input_type           Suffix of the complex element type (with `.real` / `.imag` members).
 *  @param accumulator_type     Suffix of the scalar type used for the converted parts and both totals.
 *  @param output_complex_type  Suffix of the complex type written through `result`.
 *  @param load_and_convert     Invoked as `load_and_convert(source_pointer, destination_pointer)`.
 */
#define nk_define_dot_complex_(input_type, accumulator_type, output_complex_type, load_and_convert)       \
    NK_PUBLIC void nk_dot_##input_type##_serial(nk_##input_type##_t const *a_pairs,                        \
                                                 nk_##input_type##_t const *b_pairs, nk_size_t count_pairs, \
                                                 nk_##output_complex_type##_t *result) {                  \
        nk_##accumulator_type##_t acc_re = 0, acc_im = 0;                                                  \
        for (nk_size_t k = 0; k < count_pairs; ++k) {                                                      \
            nk_##input_type##_t const *pa = &a_pairs[k], *pb = &b_pairs[k];                                \
            nk_##accumulator_type##_t ar, ai, br, bi;                                                      \
            load_and_convert(&pa->real, &ar);                                                              \
            load_and_convert(&pa->imag, &ai);                                                              \
            load_and_convert(&pb->real, &br);                                                              \
            load_and_convert(&pb->imag, &bi);                                                              \
            acc_re += ar * br - ai * bi;                                                                   \
            acc_im += ar * bi + ai * br;                                                                   \
        }                                                                                                  \
        result->real = acc_re;                                                                             \
        result->imag = acc_im;                                                                             \
    }
|
|
123
|
+
|
|
124
|
+
/**
 *  @brief  Stamps out `nk_vdot_<input_type>_serial`: the complex dot product with `a` conjugated,
 *          i.e. the sum over k of conj(a[k]) * b[k].
 *
 *  The real and imaginary parts of each pair are widened via `load_and_convert`. The real accumulator
 *  gathers `ar*br + ai*bi` and the imaginary accumulator gathers `ar*bi - ai*br`. Both accumulators are
 *  assigned to `result->real` / `result->imag` without an explicit cast.
 *
 *  @param input_type           Suffix of the complex element type (with `.real` / `.imag` members).
 *  @param accumulator_type     Suffix of the scalar type used for the converted parts and both totals.
 *  @param output_complex_type  Suffix of the complex type written through `result`.
 *  @param load_and_convert     Invoked as `load_and_convert(source_pointer, destination_pointer)`.
 */
#define nk_define_vdot_complex_(input_type, accumulator_type, output_complex_type, load_and_convert)      \
    NK_PUBLIC void nk_vdot_##input_type##_serial(nk_##input_type##_t const *a_pairs,                       \
                                                  nk_##input_type##_t const *b_pairs, nk_size_t count_pairs, \
                                                  nk_##output_complex_type##_t *result) {                 \
        nk_##accumulator_type##_t acc_re = 0, acc_im = 0;                                                  \
        for (nk_size_t k = 0; k < count_pairs; ++k) {                                                      \
            nk_##input_type##_t const *pa = &a_pairs[k], *pb = &b_pairs[k];                                \
            nk_##accumulator_type##_t ar, ai, br, bi;                                                      \
            load_and_convert(&pa->real, &ar);                                                              \
            load_and_convert(&pa->imag, &ai);                                                              \
            load_and_convert(&pb->real, &br);                                                              \
            load_and_convert(&pb->imag, &bi);                                                              \
            acc_re += ar * br + ai * bi;                                                                   \
            acc_im += ar * bi - ai * br;                                                                   \
        }                                                                                                  \
        result->real = acc_re;                                                                             \
        result->imag = acc_im;                                                                             \
    }
|
|
141
|
+
|
|
142
|
+
#pragma region - Traditional Floats
|
|
143
|
+
|
|
144
|
+
nk_define_dot_(f32, f64, f64, nk_assign_from_to_) // nk_dot_f32_serial
|
|
145
|
+
nk_define_dot_complex_(f32c, f64, f64c, nk_assign_from_to_) // nk_dot_f32c_serial
|
|
146
|
+
nk_define_vdot_complex_(f32c, f64, f64c, nk_assign_from_to_) // nk_vdot_f32c_serial
|
|
147
|
+
|
|
148
|
+
#pragma endregion - Traditional Floats
|
|
149
|
+
|
|
150
|
+
#pragma region - Smaller Floats
|
|
151
|
+
|
|
152
|
+
nk_define_dot_(f16, f32, f32, nk_f16_to_f32_serial) // nk_dot_f16_serial
|
|
153
|
+
nk_define_dot_complex_(f16c, f32, f32c, nk_f16_to_f32_serial) // nk_dot_f16c_serial
|
|
154
|
+
nk_define_vdot_complex_(f16c, f32, f32c, nk_f16_to_f32_serial) // nk_vdot_f16c_serial
|
|
155
|
+
|
|
156
|
+
nk_define_dot_(bf16, f32, f32, nk_bf16_to_f32_serial) // nk_dot_bf16_serial
|
|
157
|
+
nk_define_dot_complex_(bf16c, f32, f32c, nk_bf16_to_f32_serial) // nk_dot_bf16c_serial
|
|
158
|
+
nk_define_vdot_complex_(bf16c, f32, f32c, nk_bf16_to_f32_serial) // nk_vdot_bf16c_serial
|
|
159
|
+
|
|
160
|
+
nk_define_dot_(e4m3, f32, f32, nk_e4m3_to_f32_serial) // nk_dot_e4m3_serial
|
|
161
|
+
nk_define_dot_(e5m2, f32, f32, nk_e5m2_to_f32_serial) // nk_dot_e5m2_serial
|
|
162
|
+
nk_define_dot_(e2m3, f32, f32, nk_e2m3_to_f32_serial) // nk_dot_e2m3_serial
|
|
163
|
+
nk_define_dot_(e3m2, f32, f32, nk_e3m2_to_f32_serial) // nk_dot_e3m2_serial
|
|
164
|
+
|
|
165
|
+
#pragma endregion - Smaller Floats
|
|
166
|
+
|
|
167
|
+
#pragma region - Small Integers
|
|
168
|
+
|
|
169
|
+
nk_define_dot_(i8, i32, i32, nk_assign_from_to_) // nk_dot_i8_serial
|
|
170
|
+
nk_define_dot_(u8, u32, u32, nk_assign_from_to_) // nk_dot_u8_serial
|
|
171
|
+
|
|
172
|
+
#undef nk_define_dot_
|
|
173
|
+
#undef nk_define_dot_complex_
|
|
174
|
+
#undef nk_define_vdot_complex_
|
|
175
|
+
|
|
176
|
+
/**
 *  @brief  Dot product of two packed signed 4-bit vectors, accumulated in i32.
 *
 *  Each `nk_i4x2_t` byte holds two nibbles, and `n` counts nibbles (dimensions), not bytes.
 *  An odd `n` is rounded up to whole bytes, so the final byte is processed in full, including its unused
 *  nibble. NOTE(review): this presumably relies on that padding nibble being zero in at least one input.
 *
 *  @param a       First packed input, at least ceil(n / 2) bytes.
 *  @param b       Second packed input, at least ceil(n / 2) bytes.
 *  @param n       Number of 4-bit dimensions.
 *  @param result  Receives the i32 dot product.
 */
NK_PUBLIC void nk_dot_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result) {
    nk_size_t const byte_count = nk_size_round_up_to_multiple_(n, 2) / 2;
    nk_i32_t total = 0;
    for (nk_size_t byte = 0; byte != byte_count; ++byte) {
        nk_i4x2_t const packed_a = a[byte], packed_b = b[byte];
        // The nibble helpers are expected to sign-extend: (nibble ^ 8) - 8 maps [0,15] to [-8,7].
        nk_i32_t const a_lo = (nk_i32_t)nk_i4x2_low_(packed_a), a_hi = (nk_i32_t)nk_i4x2_high_(packed_a);
        nk_i32_t const b_lo = (nk_i32_t)nk_i4x2_low_(packed_b), b_hi = (nk_i32_t)nk_i4x2_high_(packed_b);
        total += a_lo * b_lo + a_hi * b_hi;
    }
    *result = total;
}
|
|
192
|
+
|
|
193
|
+
/**
 *  @brief  Dot product of two packed unsigned 4-bit vectors, accumulated in u32 (wrapping on overflow).
 *
 *  Each `nk_u4x2_t` byte holds two nibbles in [0,15], and `n` counts nibbles (dimensions), not bytes.
 *  An odd `n` is rounded up to whole bytes, so the final byte is processed in full, including its unused
 *  nibble. NOTE(review): this presumably relies on that padding nibble being zero in at least one input.
 *
 *  @param a       First packed input, at least ceil(n / 2) bytes.
 *  @param b       Second packed input, at least ceil(n / 2) bytes.
 *  @param n       Number of 4-bit dimensions.
 *  @param result  Receives the u32 dot product.
 */
NK_PUBLIC void nk_dot_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
    nk_size_t const byte_count = nk_size_round_up_to_multiple_(n, 2) / 2;
    nk_u32_t total = 0;
    for (nk_size_t byte = 0; byte != byte_count; ++byte) {
        nk_u4x2_t const packed_a = a[byte], packed_b = b[byte];
        // Unsigned nibbles need no sign extension.
        nk_u32_t const a_lo = (nk_u32_t)nk_u4x2_low_(packed_a), a_hi = (nk_u32_t)nk_u4x2_high_(packed_a);
        nk_u32_t const b_lo = (nk_u32_t)nk_u4x2_low_(packed_b), b_hi = (nk_u32_t)nk_u4x2_high_(packed_b);
        total += a_lo * b_lo + a_hi * b_hi;
    }
    *result = total;
}
|
|
209
|
+
|
|
210
|
+
#pragma endregion - Small Integers
|
|
211
|
+
|
|
212
|
+
#pragma region - Traditional Floats
|
|
213
|
+
|
|
214
|
+
/* Double-precision dot-product variants
|
|
215
|
+
*
|
|
216
|
+
* Implements Neumaier's Kahan-Babuška variant to minimize floating-point rounding errors.
|
|
217
|
+
* Unlike Kahan, Neumaier handles the case where the term being added is larger than the
|
|
218
|
+
* running sum. Achieves O(1) error growth regardless of vector dimension.
|
|
219
|
+
*
|
|
220
|
+
* Algorithm: For each term, compute t = sum + term, then:
|
|
221
|
+
* - If |sum| ≥ |term|: c += (sum - t) + term (recovers the low-order bits of term lost in t)
* - Else: c += (term - t) + sum (recovers the low-order bits of sum lost in t)
|
|
223
|
+
*
|
|
224
|
+
* @see Neumaier, A. (1974). "Rundungsfehleranalyse einiger Verfahren zur Summation endlicher Summen"
|
|
225
|
+
*/
|
|
226
|
+
NK_PUBLIC void nk_dot_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    // `running` is the rounded partial sum and `carry` the compensation term maintained by `nk_f64_dot2_`.
    // The two are only combined once, after every product has been accumulated.
    nk_f64_t running = 0;
    nk_f64_t carry = 0;
    for (nk_size_t idx = 0; idx < n; ++idx) {
        nk_f64_dot2_(&running, &carry, a[idx], b[idx]);
    }
    *result = running + carry;
}
|
|
231
|
+
|
|
232
|
+
/**
 *  @brief  Compensated complex dot product without conjugation: the sum over k of a[k] * b[k].
 *
 *  The real and imaginary parts keep independent (sum, compensation) pairs, each updated via
 *  `nk_f64_dot2_` and combined only when the result is written.
 */
NK_PUBLIC void nk_dot_f64c_serial(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                  nk_f64c_t *result) {
    nk_f64_t real_sum = 0, real_comp = 0;
    nk_f64_t imag_sum = 0, imag_comp = 0;
    for (nk_size_t k = 0; k < count_pairs; ++k) {
        nk_f64c_t const *x = a_pairs + k, *y = b_pairs + k;
        // Real part x.re*y.re - x.im*y.im: the subtraction is folded into a negated operand.
        // Imaginary part x.re*y.im + x.im*y.re. The two accumulator pairs are independent,
        // so interleaving their updates does not change either result.
        nk_f64_dot2_(&real_sum, &real_comp, x->real, y->real);
        nk_f64_dot2_(&imag_sum, &imag_comp, x->real, y->imag);
        nk_f64_dot2_(&real_sum, &real_comp, -x->imag, y->imag);
        nk_f64_dot2_(&imag_sum, &imag_comp, x->imag, y->real);
    }
    result->real = real_sum + real_comp;
    result->imag = imag_sum + imag_comp;
}
|
|
246
|
+
|
|
247
|
+
/**
 *  @brief  Compensated complex dot product with `a` conjugated: the sum over k of conj(a[k]) * b[k].
 *
 *  The real and imaginary parts keep independent (sum, compensation) pairs, each updated via
 *  `nk_f64_dot2_` and combined only when the result is written.
 */
NK_PUBLIC void nk_vdot_f64c_serial(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                   nk_f64c_t *result) {
    nk_f64_t real_sum = 0, real_comp = 0;
    nk_f64_t imag_sum = 0, imag_comp = 0;
    for (nk_size_t k = 0; k < count_pairs; ++k) {
        nk_f64c_t const *x = a_pairs + k, *y = b_pairs + k;
        // Real part x.re*y.re + x.im*y.im.
        // Imaginary part x.re*y.im - x.im*y.re: the subtraction is folded into a negated operand.
        // The two accumulator pairs are independent, so interleaving their updates is safe.
        nk_f64_dot2_(&real_sum, &real_comp, x->real, y->real);
        nk_f64_dot2_(&imag_sum, &imag_comp, x->real, y->imag);
        nk_f64_dot2_(&real_sum, &real_comp, x->imag, y->imag);
        nk_f64_dot2_(&imag_sum, &imag_comp, -x->imag, y->real);
    }
    result->real = real_sum + real_comp;
    result->imag = imag_sum + imag_comp;
}
|
|
261
|
+
|
|
262
|
+
typedef struct nk_dot_f64x2_state_serial_t {
|
|
263
|
+
nk_f64_t sums[2];
|
|
264
|
+
nk_f64_t compensations[2];
|
|
265
|
+
} nk_dot_f64x2_state_serial_t;
|
|
266
|
+
|
|
267
|
+
NK_INTERNAL void nk_dot_f64x2_init_serial(nk_dot_f64x2_state_serial_t *state) {
|
|
268
|
+
state->sums[0] = 0, state->sums[1] = 0;
|
|
269
|
+
state->compensations[0] = 0, state->compensations[1] = 0;
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
/**
 *  @brief  Feeds one 2-element f64 slice of `a` and `b` into the per-lane compensated accumulators.
 *
 *  @param depth_offset       Ignored by the serial kernel.
 *  @param active_dimensions  Ignored: both lanes are always accumulated. NOTE(review): tail slices are
 *                            presumably zero-padded by the caller — confirm.
 */
NK_INTERNAL void nk_dot_f64x2_update_serial(nk_dot_f64x2_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                            nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    for (nk_size_t lane = 0; lane < 2; ++lane) {
        nk_f64_dot2_(&state->sums[lane], &state->compensations[lane], a.f64s[lane], b.f64s[lane]);
    }
}
|
|
284
|
+
|
|
285
|
+
NK_INTERNAL void nk_dot_f64x2_finalize_serial( //
|
|
286
|
+
nk_dot_f64x2_state_serial_t const *state_a, nk_dot_f64x2_state_serial_t const *state_b, //
|
|
287
|
+
nk_dot_f64x2_state_serial_t const *state_c, nk_dot_f64x2_state_serial_t const *state_d, //
|
|
288
|
+
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
289
|
+
nk_unused_(total_dimensions);
|
|
290
|
+
result->f64s[0] = nk_reduce_sum_f64_serial_(state_a->sums, state_a->compensations, 2);
|
|
291
|
+
result->f64s[1] = nk_reduce_sum_f64_serial_(state_b->sums, state_b->compensations, 2);
|
|
292
|
+
result->f64s[2] = nk_reduce_sum_f64_serial_(state_c->sums, state_c->compensations, 2);
|
|
293
|
+
result->f64s[3] = nk_reduce_sum_f64_serial_(state_d->sums, state_d->compensations, 2);
|
|
294
|
+
}
|
|
295
|
+
|
|
296
|
+
typedef struct nk_dot_f32x4_state_serial_t {
|
|
297
|
+
nk_f64_t sums[4];
|
|
298
|
+
} nk_dot_f32x4_state_serial_t;
|
|
299
|
+
|
|
300
|
+
NK_INTERNAL void nk_dot_f32x4_init_serial(nk_dot_f32x4_state_serial_t *state) {
|
|
301
|
+
state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
/**
 *  @brief  Adds the lane-wise products of one 4-element f32 slice of `a` and `b` into the f64 lane sums.
 *
 *  @param depth_offset       Ignored by the serial kernel.
 *  @param active_dimensions  Ignored: all four lanes are always accumulated.
 */
NK_INTERNAL void nk_dot_f32x4_update_serial(nk_dot_f32x4_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
                                            nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    for (nk_size_t lane = 0; lane < 4; ++lane) {
        // Widening the left operand first makes the whole product evaluate in f64.
        state->sums[lane] += (nk_f64_t)a.f32s[lane] * b.f32s[lane];
    }
}
|
|
316
|
+
|
|
317
|
+
NK_INTERNAL void nk_dot_f32x4_finalize_serial( //
|
|
318
|
+
nk_dot_f32x4_state_serial_t const *state_a, nk_dot_f32x4_state_serial_t const *state_b, //
|
|
319
|
+
nk_dot_f32x4_state_serial_t const *state_c, nk_dot_f32x4_state_serial_t const *state_d, //
|
|
320
|
+
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
321
|
+
nk_unused_(total_dimensions);
|
|
322
|
+
result->f64s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
|
|
323
|
+
result->f64s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
|
|
324
|
+
result->f64s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
|
|
325
|
+
result->f64s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
#pragma endregion - Traditional Floats
|
|
329
|
+
|
|
330
|
+
#pragma region - Smaller Floats
|
|
331
|
+
|
|
332
|
+
typedef struct nk_dot_f16x8_state_serial_t {
|
|
333
|
+
nk_f32_t sums[4];
|
|
334
|
+
} nk_dot_f16x8_state_serial_t;
|
|
335
|
+
|
|
336
|
+
NK_INTERNAL void nk_dot_f16x8_init_serial(nk_dot_f16x8_state_serial_t *state) {
|
|
337
|
+
state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
NK_INTERNAL void nk_dot_f16x8_update_serial(nk_dot_f16x8_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
341
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
342
|
+
nk_unused_(depth_offset);
|
|
343
|
+
nk_unused_(active_dimensions);
|
|
344
|
+
nk_f32_t sum0 = state->sums[0], sum1 = state->sums[1], sum2 = state->sums[2], sum3 = state->sums[3];
|
|
345
|
+
for (nk_size_t i = 0; i < 8; i += 4) {
|
|
346
|
+
nk_f32_t a0, a1, a2, a3, b0, b1, b2, b3;
|
|
347
|
+
nk_f16_to_f32_serial(a.f16s + i + 0, &a0), nk_f16_to_f32_serial(a.f16s + i + 1, &a1);
|
|
348
|
+
nk_f16_to_f32_serial(a.f16s + i + 2, &a2), nk_f16_to_f32_serial(a.f16s + i + 3, &a3);
|
|
349
|
+
nk_f16_to_f32_serial(b.f16s + i + 0, &b0), nk_f16_to_f32_serial(b.f16s + i + 1, &b1);
|
|
350
|
+
nk_f16_to_f32_serial(b.f16s + i + 2, &b2), nk_f16_to_f32_serial(b.f16s + i + 3, &b3);
|
|
351
|
+
sum0 += a0 * b0, sum1 += a1 * b1, sum2 += a2 * b2, sum3 += a3 * b3;
|
|
352
|
+
}
|
|
353
|
+
state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
NK_INTERNAL void nk_dot_f16x8_finalize_serial( //
|
|
357
|
+
nk_dot_f16x8_state_serial_t const *state_a, nk_dot_f16x8_state_serial_t const *state_b, //
|
|
358
|
+
nk_dot_f16x8_state_serial_t const *state_c, nk_dot_f16x8_state_serial_t const *state_d, //
|
|
359
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
360
|
+
nk_unused_(total_dimensions);
|
|
361
|
+
result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
|
|
362
|
+
result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
|
|
363
|
+
result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
|
|
364
|
+
result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
typedef struct nk_dot_bf16x8_state_serial_t {
|
|
368
|
+
nk_f32_t sums[4];
|
|
369
|
+
} nk_dot_bf16x8_state_serial_t;
|
|
370
|
+
|
|
371
|
+
NK_INTERNAL void nk_dot_bf16x8_init_serial(nk_dot_bf16x8_state_serial_t *state) {
|
|
372
|
+
state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
|
|
373
|
+
}
|
|
374
|
+
|
|
375
|
+
NK_INTERNAL void nk_dot_bf16x8_update_serial(nk_dot_bf16x8_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
376
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
377
|
+
nk_unused_(depth_offset);
|
|
378
|
+
nk_unused_(active_dimensions);
|
|
379
|
+
nk_f32_t sum0 = state->sums[0], sum1 = state->sums[1], sum2 = state->sums[2], sum3 = state->sums[3];
|
|
380
|
+
for (nk_size_t i = 0; i < 8; i += 4) {
|
|
381
|
+
nk_f32_t a0, a1, a2, a3, b0, b1, b2, b3;
|
|
382
|
+
nk_bf16_to_f32_serial(a.bf16s + i + 0, &a0), nk_bf16_to_f32_serial(a.bf16s + i + 1, &a1);
|
|
383
|
+
nk_bf16_to_f32_serial(a.bf16s + i + 2, &a2), nk_bf16_to_f32_serial(a.bf16s + i + 3, &a3);
|
|
384
|
+
nk_bf16_to_f32_serial(b.bf16s + i + 0, &b0), nk_bf16_to_f32_serial(b.bf16s + i + 1, &b1);
|
|
385
|
+
nk_bf16_to_f32_serial(b.bf16s + i + 2, &b2), nk_bf16_to_f32_serial(b.bf16s + i + 3, &b3);
|
|
386
|
+
sum0 += a0 * b0, sum1 += a1 * b1, sum2 += a2 * b2, sum3 += a3 * b3;
|
|
387
|
+
}
|
|
388
|
+
state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
NK_INTERNAL void nk_dot_bf16x8_finalize_serial( //
|
|
392
|
+
nk_dot_bf16x8_state_serial_t const *state_a, nk_dot_bf16x8_state_serial_t const *state_b, //
|
|
393
|
+
nk_dot_bf16x8_state_serial_t const *state_c, nk_dot_bf16x8_state_serial_t const *state_d, //
|
|
394
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
395
|
+
nk_unused_(total_dimensions);
|
|
396
|
+
result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
|
|
397
|
+
result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
|
|
398
|
+
result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
|
|
399
|
+
result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
400
|
+
}
|
|
401
|
+
|
|
402
|
+
#pragma endregion - Smaller Floats
|
|
403
|
+
|
|
404
|
+
#pragma region - Small Integers
|
|
405
|
+
|
|
406
|
+
typedef struct nk_dot_i8x16_state_serial_t {
|
|
407
|
+
nk_i64_t sums[2];
|
|
408
|
+
} nk_dot_i8x16_state_serial_t;
|
|
409
|
+
|
|
410
|
+
NK_INTERNAL void nk_dot_i8x16_init_serial(nk_dot_i8x16_state_serial_t *state) {
|
|
411
|
+
state->sums[0] = 0, state->sums[1] = 0;
|
|
412
|
+
}
|
|
413
|
+
|
|
414
|
+
NK_INTERNAL void nk_dot_i8x16_update_serial(nk_dot_i8x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
415
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
416
|
+
nk_unused_(depth_offset);
|
|
417
|
+
nk_unused_(active_dimensions);
|
|
418
|
+
nk_i64_t sum0 = state->sums[0];
|
|
419
|
+
nk_i64_t sum1 = state->sums[1];
|
|
420
|
+
sum0 += (nk_i16_t)a.i8s[0] * (nk_i16_t)b.i8s[0], sum1 += (nk_i16_t)a.i8s[1] * (nk_i16_t)b.i8s[1];
|
|
421
|
+
sum0 += (nk_i16_t)a.i8s[2] * (nk_i16_t)b.i8s[2], sum1 += (nk_i16_t)a.i8s[3] * (nk_i16_t)b.i8s[3];
|
|
422
|
+
sum0 += (nk_i16_t)a.i8s[4] * (nk_i16_t)b.i8s[4], sum1 += (nk_i16_t)a.i8s[5] * (nk_i16_t)b.i8s[5];
|
|
423
|
+
sum0 += (nk_i16_t)a.i8s[6] * (nk_i16_t)b.i8s[6], sum1 += (nk_i16_t)a.i8s[7] * (nk_i16_t)b.i8s[7];
|
|
424
|
+
sum0 += (nk_i16_t)a.i8s[8] * (nk_i16_t)b.i8s[8], sum1 += (nk_i16_t)a.i8s[9] * (nk_i16_t)b.i8s[9];
|
|
425
|
+
sum0 += (nk_i16_t)a.i8s[10] * (nk_i16_t)b.i8s[10], sum1 += (nk_i16_t)a.i8s[11] * (nk_i16_t)b.i8s[11];
|
|
426
|
+
sum0 += (nk_i16_t)a.i8s[12] * (nk_i16_t)b.i8s[12], sum1 += (nk_i16_t)a.i8s[13] * (nk_i16_t)b.i8s[13];
|
|
427
|
+
sum0 += (nk_i16_t)a.i8s[14] * (nk_i16_t)b.i8s[14], sum1 += (nk_i16_t)a.i8s[15] * (nk_i16_t)b.i8s[15];
|
|
428
|
+
state->sums[0] = sum0, state->sums[1] = sum1;
|
|
429
|
+
}
|
|
430
|
+
|
|
431
|
+
NK_INTERNAL void nk_dot_i8x16_finalize_serial( //
|
|
432
|
+
nk_dot_i8x16_state_serial_t const *state_a, nk_dot_i8x16_state_serial_t const *state_b, //
|
|
433
|
+
nk_dot_i8x16_state_serial_t const *state_c, nk_dot_i8x16_state_serial_t const *state_d, //
|
|
434
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
435
|
+
nk_unused_(total_dimensions);
|
|
436
|
+
result->i32s[0] = (nk_i32_t)(state_a->sums[0] + state_a->sums[1]);
|
|
437
|
+
result->i32s[1] = (nk_i32_t)(state_b->sums[0] + state_b->sums[1]);
|
|
438
|
+
result->i32s[2] = (nk_i32_t)(state_c->sums[0] + state_c->sums[1]);
|
|
439
|
+
result->i32s[3] = (nk_i32_t)(state_d->sums[0] + state_d->sums[1]);
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
typedef struct nk_dot_u8x16_state_serial_t {
|
|
443
|
+
nk_u64_t sums[2];
|
|
444
|
+
} nk_dot_u8x16_state_serial_t;
|
|
445
|
+
|
|
446
|
+
NK_INTERNAL void nk_dot_u8x16_init_serial(nk_dot_u8x16_state_serial_t *state) {
|
|
447
|
+
state->sums[0] = 0, state->sums[1] = 0;
|
|
448
|
+
}
|
|
449
|
+
|
|
450
|
+
NK_INTERNAL void nk_dot_u8x16_update_serial(nk_dot_u8x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
451
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
452
|
+
nk_unused_(depth_offset);
|
|
453
|
+
nk_unused_(active_dimensions);
|
|
454
|
+
nk_u64_t sum0 = state->sums[0];
|
|
455
|
+
nk_u64_t sum1 = state->sums[1];
|
|
456
|
+
|
|
457
|
+
sum0 += (nk_u16_t)a.u8s[0] * (nk_u16_t)b.u8s[0], sum1 += (nk_u16_t)a.u8s[1] * (nk_u16_t)b.u8s[1];
|
|
458
|
+
sum0 += (nk_u16_t)a.u8s[2] * (nk_u16_t)b.u8s[2], sum1 += (nk_u16_t)a.u8s[3] * (nk_u16_t)b.u8s[3];
|
|
459
|
+
sum0 += (nk_u16_t)a.u8s[4] * (nk_u16_t)b.u8s[4], sum1 += (nk_u16_t)a.u8s[5] * (nk_u16_t)b.u8s[5];
|
|
460
|
+
sum0 += (nk_u16_t)a.u8s[6] * (nk_u16_t)b.u8s[6], sum1 += (nk_u16_t)a.u8s[7] * (nk_u16_t)b.u8s[7];
|
|
461
|
+
sum0 += (nk_u16_t)a.u8s[8] * (nk_u16_t)b.u8s[8], sum1 += (nk_u16_t)a.u8s[9] * (nk_u16_t)b.u8s[9];
|
|
462
|
+
sum0 += (nk_u16_t)a.u8s[10] * (nk_u16_t)b.u8s[10], sum1 += (nk_u16_t)a.u8s[11] * (nk_u16_t)b.u8s[11];
|
|
463
|
+
sum0 += (nk_u16_t)a.u8s[12] * (nk_u16_t)b.u8s[12], sum1 += (nk_u16_t)a.u8s[13] * (nk_u16_t)b.u8s[13];
|
|
464
|
+
sum0 += (nk_u16_t)a.u8s[14] * (nk_u16_t)b.u8s[14], sum1 += (nk_u16_t)a.u8s[15] * (nk_u16_t)b.u8s[15];
|
|
465
|
+
state->sums[0] = sum0, state->sums[1] = sum1;
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
NK_INTERNAL void nk_dot_u8x16_finalize_serial( //
|
|
469
|
+
nk_dot_u8x16_state_serial_t const *state_a, nk_dot_u8x16_state_serial_t const *state_b, //
|
|
470
|
+
nk_dot_u8x16_state_serial_t const *state_c, nk_dot_u8x16_state_serial_t const *state_d, //
|
|
471
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
472
|
+
nk_unused_(total_dimensions);
|
|
473
|
+
result->u32s[0] = (nk_u32_t)(state_a->sums[0] + state_a->sums[1]);
|
|
474
|
+
result->u32s[1] = (nk_u32_t)(state_b->sums[0] + state_b->sums[1]);
|
|
475
|
+
result->u32s[2] = (nk_u32_t)(state_c->sums[0] + state_c->sums[1]);
|
|
476
|
+
result->u32s[3] = (nk_u32_t)(state_d->sums[0] + state_d->sums[1]);
|
|
477
|
+
}
|
|
478
|
+
|
|
479
|
+
#pragma endregion - Small Integers
|
|
480
|
+
|
|
481
|
+
#pragma region - Smaller Floats
|
|
482
|
+
|
|
483
|
+
typedef struct nk_dot_e4m3x16_state_serial_t {
|
|
484
|
+
nk_f32_t sums[4];
|
|
485
|
+
} nk_dot_e4m3x16_state_serial_t;
|
|
486
|
+
|
|
487
|
+
NK_INTERNAL void nk_dot_e4m3x16_init_serial(nk_dot_e4m3x16_state_serial_t *state) {
|
|
488
|
+
state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
NK_INTERNAL void nk_dot_e4m3x16_update_serial(nk_dot_e4m3x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
492
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
493
|
+
nk_unused_(depth_offset);
|
|
494
|
+
nk_unused_(active_dimensions);
|
|
495
|
+
nk_f32_t sum0 = state->sums[0];
|
|
496
|
+
nk_f32_t sum1 = state->sums[1];
|
|
497
|
+
nk_f32_t sum2 = state->sums[2];
|
|
498
|
+
nk_f32_t sum3 = state->sums[3];
|
|
499
|
+
nk_f32_t ai0, ai1, ai2, ai3;
|
|
500
|
+
nk_f32_t bi0, bi1, bi2, bi3;
|
|
501
|
+
for (nk_size_t i = 0; i != 16; i += 4) {
|
|
502
|
+
nk_e4m3_to_f32_serial(a.e4m3s + i, &ai0), nk_e4m3_to_f32_serial(b.e4m3s + i, &bi0);
|
|
503
|
+
nk_e4m3_to_f32_serial(a.e4m3s + i + 1, &ai1), nk_e4m3_to_f32_serial(b.e4m3s + i + 1, &bi1);
|
|
504
|
+
nk_e4m3_to_f32_serial(a.e4m3s + i + 2, &ai2), nk_e4m3_to_f32_serial(b.e4m3s + i + 2, &bi2);
|
|
505
|
+
nk_e4m3_to_f32_serial(a.e4m3s + i + 3, &ai3), nk_e4m3_to_f32_serial(b.e4m3s + i + 3, &bi3);
|
|
506
|
+
sum0 += ai0 * bi0, sum1 += ai1 * bi1, sum2 += ai2 * bi2, sum3 += ai3 * bi3;
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
NK_INTERNAL void nk_dot_e4m3x16_finalize_serial( //
|
|
513
|
+
nk_dot_e4m3x16_state_serial_t const *state_a, nk_dot_e4m3x16_state_serial_t const *state_b, //
|
|
514
|
+
nk_dot_e4m3x16_state_serial_t const *state_c, nk_dot_e4m3x16_state_serial_t const *state_d, //
|
|
515
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
516
|
+
nk_unused_(total_dimensions);
|
|
517
|
+
result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
|
|
518
|
+
result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
|
|
519
|
+
result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
|
|
520
|
+
result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
521
|
+
}
|
|
522
|
+
|
|
523
|
+
typedef struct nk_dot_e5m2x16_state_serial_t {
|
|
524
|
+
nk_f32_t sums[4];
|
|
525
|
+
} nk_dot_e5m2x16_state_serial_t;
|
|
526
|
+
|
|
527
|
+
NK_INTERNAL void nk_dot_e5m2x16_init_serial(nk_dot_e5m2x16_state_serial_t *state) {
|
|
528
|
+
state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
NK_INTERNAL void nk_dot_e5m2x16_update_serial(nk_dot_e5m2x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
532
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
533
|
+
nk_unused_(depth_offset);
|
|
534
|
+
nk_unused_(active_dimensions);
|
|
535
|
+
nk_f32_t sum0 = state->sums[0];
|
|
536
|
+
nk_f32_t sum1 = state->sums[1];
|
|
537
|
+
nk_f32_t sum2 = state->sums[2];
|
|
538
|
+
nk_f32_t sum3 = state->sums[3];
|
|
539
|
+
nk_f32_t ai0, ai1, ai2, ai3;
|
|
540
|
+
nk_f32_t bi0, bi1, bi2, bi3;
|
|
541
|
+
for (nk_size_t i = 0; i != 16; i += 4) {
|
|
542
|
+
nk_e5m2_to_f32_serial(a.e5m2s + i, &ai0), nk_e5m2_to_f32_serial(b.e5m2s + i, &bi0);
|
|
543
|
+
nk_e5m2_to_f32_serial(a.e5m2s + i + 1, &ai1), nk_e5m2_to_f32_serial(b.e5m2s + i + 1, &bi1);
|
|
544
|
+
nk_e5m2_to_f32_serial(a.e5m2s + i + 2, &ai2), nk_e5m2_to_f32_serial(b.e5m2s + i + 2, &bi2);
|
|
545
|
+
nk_e5m2_to_f32_serial(a.e5m2s + i + 3, &ai3), nk_e5m2_to_f32_serial(b.e5m2s + i + 3, &bi3);
|
|
546
|
+
sum0 += ai0 * bi0, sum1 += ai1 * bi1, sum2 += ai2 * bi2, sum3 += ai3 * bi3;
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
|
|
550
|
+
}
|
|
551
|
+
|
|
552
|
+
NK_INTERNAL void nk_dot_e5m2x16_finalize_serial( //
|
|
553
|
+
nk_dot_e5m2x16_state_serial_t const *state_a, nk_dot_e5m2x16_state_serial_t const *state_b, //
|
|
554
|
+
nk_dot_e5m2x16_state_serial_t const *state_c, nk_dot_e5m2x16_state_serial_t const *state_d, //
|
|
555
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
556
|
+
nk_unused_(total_dimensions);
|
|
557
|
+
result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
|
|
558
|
+
result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
|
|
559
|
+
result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
|
|
560
|
+
result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
561
|
+
}
|
|
562
|
+
|
|
563
|
+
typedef struct nk_dot_e2m3x16_state_serial_t {
|
|
564
|
+
nk_f32_t sums[4];
|
|
565
|
+
} nk_dot_e2m3x16_state_serial_t;
|
|
566
|
+
|
|
567
|
+
NK_INTERNAL void nk_dot_e2m3x16_init_serial(nk_dot_e2m3x16_state_serial_t *state) {
|
|
568
|
+
state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
NK_INTERNAL void nk_dot_e2m3x16_update_serial(nk_dot_e2m3x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
572
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
573
|
+
nk_unused_(depth_offset);
|
|
574
|
+
nk_unused_(active_dimensions);
|
|
575
|
+
nk_f32_t sum0 = state->sums[0];
|
|
576
|
+
nk_f32_t sum1 = state->sums[1];
|
|
577
|
+
nk_f32_t sum2 = state->sums[2];
|
|
578
|
+
nk_f32_t sum3 = state->sums[3];
|
|
579
|
+
nk_f32_t ai0, ai1, ai2, ai3;
|
|
580
|
+
nk_f32_t bi0, bi1, bi2, bi3;
|
|
581
|
+
for (nk_size_t i = 0; i != 16; i += 4) {
|
|
582
|
+
nk_e2m3_to_f32_serial(a.e2m3s + i, &ai0), nk_e2m3_to_f32_serial(b.e2m3s + i, &bi0);
|
|
583
|
+
nk_e2m3_to_f32_serial(a.e2m3s + i + 1, &ai1), nk_e2m3_to_f32_serial(b.e2m3s + i + 1, &bi1);
|
|
584
|
+
nk_e2m3_to_f32_serial(a.e2m3s + i + 2, &ai2), nk_e2m3_to_f32_serial(b.e2m3s + i + 2, &bi2);
|
|
585
|
+
nk_e2m3_to_f32_serial(a.e2m3s + i + 3, &ai3), nk_e2m3_to_f32_serial(b.e2m3s + i + 3, &bi3);
|
|
586
|
+
sum0 += ai0 * bi0, sum1 += ai1 * bi1, sum2 += ai2 * bi2, sum3 += ai3 * bi3;
|
|
587
|
+
}
|
|
588
|
+
|
|
589
|
+
state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
NK_INTERNAL void nk_dot_e2m3x16_finalize_serial( //
|
|
593
|
+
nk_dot_e2m3x16_state_serial_t const *state_a, nk_dot_e2m3x16_state_serial_t const *state_b, //
|
|
594
|
+
nk_dot_e2m3x16_state_serial_t const *state_c, nk_dot_e2m3x16_state_serial_t const *state_d, //
|
|
595
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
596
|
+
nk_unused_(total_dimensions);
|
|
597
|
+
result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
|
|
598
|
+
result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
|
|
599
|
+
result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
|
|
600
|
+
result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
601
|
+
}
|
|
602
|
+
|
|
603
|
+
typedef struct nk_dot_e3m2x16_state_serial_t {
|
|
604
|
+
nk_f32_t sums[4];
|
|
605
|
+
} nk_dot_e3m2x16_state_serial_t;
|
|
606
|
+
|
|
607
|
+
NK_INTERNAL void nk_dot_e3m2x16_init_serial(nk_dot_e3m2x16_state_serial_t *state) {
|
|
608
|
+
state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
|
|
609
|
+
}
|
|
610
|
+
|
|
611
|
+
NK_INTERNAL void nk_dot_e3m2x16_update_serial(nk_dot_e3m2x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
612
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
613
|
+
nk_unused_(depth_offset);
|
|
614
|
+
nk_unused_(active_dimensions);
|
|
615
|
+
nk_f32_t sum0 = state->sums[0];
|
|
616
|
+
nk_f32_t sum1 = state->sums[1];
|
|
617
|
+
nk_f32_t sum2 = state->sums[2];
|
|
618
|
+
nk_f32_t sum3 = state->sums[3];
|
|
619
|
+
nk_f32_t ai0, ai1, ai2, ai3;
|
|
620
|
+
nk_f32_t bi0, bi1, bi2, bi3;
|
|
621
|
+
for (nk_size_t i = 0; i != 16; i += 4) {
|
|
622
|
+
nk_e3m2_to_f32_serial(a.e3m2s + i, &ai0), nk_e3m2_to_f32_serial(b.e3m2s + i, &bi0);
|
|
623
|
+
nk_e3m2_to_f32_serial(a.e3m2s + i + 1, &ai1), nk_e3m2_to_f32_serial(b.e3m2s + i + 1, &bi1);
|
|
624
|
+
nk_e3m2_to_f32_serial(a.e3m2s + i + 2, &ai2), nk_e3m2_to_f32_serial(b.e3m2s + i + 2, &bi2);
|
|
625
|
+
nk_e3m2_to_f32_serial(a.e3m2s + i + 3, &ai3), nk_e3m2_to_f32_serial(b.e3m2s + i + 3, &bi3);
|
|
626
|
+
sum0 += ai0 * bi0, sum1 += ai1 * bi1, sum2 += ai2 * bi2, sum3 += ai3 * bi3;
|
|
627
|
+
}
|
|
628
|
+
|
|
629
|
+
state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
|
|
630
|
+
}
|
|
631
|
+
|
|
632
|
+
NK_INTERNAL void nk_dot_e3m2x16_finalize_serial( //
|
|
633
|
+
nk_dot_e3m2x16_state_serial_t const *state_a, nk_dot_e3m2x16_state_serial_t const *state_b, //
|
|
634
|
+
nk_dot_e3m2x16_state_serial_t const *state_c, nk_dot_e3m2x16_state_serial_t const *state_d, //
|
|
635
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
636
|
+
nk_unused_(total_dimensions);
|
|
637
|
+
result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
|
|
638
|
+
result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
|
|
639
|
+
result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
|
|
640
|
+
result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
641
|
+
}
|
|
642
|
+
|
|
643
|
+
#pragma endregion - Smaller Floats
|
|
644
|
+
|
|
645
|
+
#pragma region - Small Integers
|
|
646
|
+
|
|
647
|
+
// U4x2 state: processes 16 nibbles (8 bytes = 64 bits) per update
|
|
648
|
+
typedef struct nk_dot_u4x16_state_serial_t {
|
|
649
|
+
nk_u64_t sums[2]; // sums[0]: low nibbles, sums[1]: high nibbles
|
|
650
|
+
} nk_dot_u4x16_state_serial_t;
|
|
651
|
+
|
|
652
|
+
NK_INTERNAL void nk_dot_u4x16_init_serial(nk_dot_u4x16_state_serial_t *state) {
|
|
653
|
+
state->sums[0] = 0, state->sums[1] = 0;
|
|
654
|
+
}
|
|
655
|
+
|
|
656
|
+
NK_INTERNAL void nk_dot_u4x16_update_serial(nk_dot_u4x16_state_serial_t *state, nk_b64_vec_t a, nk_b64_vec_t b,
|
|
657
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
658
|
+
nk_unused_(depth_offset);
|
|
659
|
+
nk_unused_(active_dimensions);
|
|
660
|
+
// Process 8 bytes (16 nibbles total) using SWAR
|
|
661
|
+
// Separate accumulators for low and high nibbles
|
|
662
|
+
nk_u64_t sum_low = state->sums[0];
|
|
663
|
+
nk_u64_t sum_high = state->sums[1];
|
|
664
|
+
|
|
665
|
+
// Process all 8 bytes, extracting and multiplying nibbles
|
|
666
|
+
for (nk_size_t i = 0; i < 8; i++) {
|
|
667
|
+
nk_u8_t a_byte = a.u8s[i];
|
|
668
|
+
nk_u8_t b_byte = b.u8s[i];
|
|
669
|
+
|
|
670
|
+
// Extract low and high nibbles using SWAR masks
|
|
671
|
+
nk_u8_t a_low = a_byte & 0x0F;
|
|
672
|
+
nk_u8_t b_low = b_byte & 0x0F;
|
|
673
|
+
nk_u8_t a_high = (a_byte >> 4) & 0x0F;
|
|
674
|
+
nk_u8_t b_high = (b_byte >> 4) & 0x0F;
|
|
675
|
+
|
|
676
|
+
// Accumulate products into separate accumulators
|
|
677
|
+
sum_low += (nk_u32_t)a_low * (nk_u32_t)b_low;
|
|
678
|
+
sum_high += (nk_u32_t)a_high * (nk_u32_t)b_high;
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
state->sums[0] = sum_low, state->sums[1] = sum_high;
|
|
682
|
+
}
|
|
683
|
+
|
|
684
|
+
NK_INTERNAL void nk_dot_u4x16_finalize_serial(nk_dot_u4x16_state_serial_t const *state_a,
|
|
685
|
+
nk_dot_u4x16_state_serial_t const *state_b,
|
|
686
|
+
nk_dot_u4x16_state_serial_t const *state_c,
|
|
687
|
+
nk_dot_u4x16_state_serial_t const *state_d, nk_size_t total_dimensions,
|
|
688
|
+
nk_b128_vec_t *result) {
|
|
689
|
+
nk_unused_(total_dimensions);
|
|
690
|
+
result->u32s[0] = (nk_u32_t)(state_a->sums[0] + state_a->sums[1]);
|
|
691
|
+
result->u32s[1] = (nk_u32_t)(state_b->sums[0] + state_b->sums[1]);
|
|
692
|
+
result->u32s[2] = (nk_u32_t)(state_c->sums[0] + state_c->sums[1]);
|
|
693
|
+
result->u32s[3] = (nk_u32_t)(state_d->sums[0] + state_d->sums[1]);
|
|
694
|
+
}
|
|
695
|
+
|
|
696
|
+
NK_INTERNAL void nk_load_i4x16_to_i8x16_serial_(void const *src, nk_b128_vec_t *dst) {
|
|
697
|
+
nk_i4_to_i8_serial_((nk_i4x2_t const *)src, dst->i8s, 16);
|
|
698
|
+
}
|
|
699
|
+
|
|
700
|
+
NK_INTERNAL void nk_partial_load_i4x16_to_i8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
701
|
+
nk_i4_to_i8_serial_((nk_i4x2_t const *)src, dst->i8s, n);
|
|
702
|
+
for (nk_size_t i = n; i < 16; ++i) dst->i8s[i] = 0;
|
|
703
|
+
}
|
|
704
|
+
|
|
705
|
+
NK_INTERNAL void nk_load_u4x16_to_u8x16_serial_(void const *src, nk_b128_vec_t *dst) {
|
|
706
|
+
nk_u4_to_u8_serial_((nk_u4x2_t const *)src, dst->u8s, 16);
|
|
707
|
+
}
|
|
708
|
+
|
|
709
|
+
NK_INTERNAL void nk_partial_load_u4x16_to_u8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
710
|
+
nk_u4_to_u8_serial_((nk_u4x2_t const *)src, dst->u8s, n);
|
|
711
|
+
for (nk_size_t i = n; i < 16; ++i) dst->u8s[i] = 0;
|
|
712
|
+
}
|
|
713
|
+
|
|
714
|
+
typedef struct nk_dot_i4x16_state_serial_t {
|
|
715
|
+
nk_i64_t sums[2]; // sums[0]: low nibbles, sums[1]: high nibbles
|
|
716
|
+
} nk_dot_i4x16_state_serial_t;
|
|
717
|
+
|
|
718
|
+
NK_INTERNAL void nk_dot_i4x16_init_serial(nk_dot_i4x16_state_serial_t *state) {
|
|
719
|
+
state->sums[0] = 0, state->sums[1] = 0;
|
|
720
|
+
}
|
|
721
|
+
|
|
722
|
+
NK_INTERNAL void nk_dot_i4x16_update_serial(nk_dot_i4x16_state_serial_t *state, nk_b64_vec_t a, nk_b64_vec_t b,
|
|
723
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
724
|
+
nk_unused_(depth_offset);
|
|
725
|
+
nk_unused_(active_dimensions);
|
|
726
|
+
// Process 8 bytes (16 nibbles total) using SWAR with sign extension
|
|
727
|
+
// Separate accumulators for low and high nibbles
|
|
728
|
+
nk_i64_t sum_low = state->sums[0];
|
|
729
|
+
nk_i64_t sum_high = state->sums[1];
|
|
730
|
+
|
|
731
|
+
// Process all 8 bytes, extracting and multiplying signed nibbles
|
|
732
|
+
for (nk_size_t i = 0; i < 8; i++) {
|
|
733
|
+
nk_u8_t a_byte = a.u8s[i];
|
|
734
|
+
nk_u8_t b_byte = b.u8s[i];
|
|
735
|
+
|
|
736
|
+
// Extract nibbles and sign extend: (nibble ^ 8) - 8 maps [0,15] → [-8,7]
|
|
737
|
+
nk_i8_t a_low = (nk_i8_t)(((a_byte & 0x0F) ^ 8) - 8);
|
|
738
|
+
nk_i8_t b_low = (nk_i8_t)(((b_byte & 0x0F) ^ 8) - 8);
|
|
739
|
+
nk_i8_t a_high = (nk_i8_t)((((a_byte >> 4) & 0x0F) ^ 8) - 8);
|
|
740
|
+
nk_i8_t b_high = (nk_i8_t)((((b_byte >> 4) & 0x0F) ^ 8) - 8);
|
|
741
|
+
|
|
742
|
+
// Accumulate products into separate accumulators
|
|
743
|
+
sum_low += (nk_i32_t)a_low * (nk_i32_t)b_low;
|
|
744
|
+
sum_high += (nk_i32_t)a_high * (nk_i32_t)b_high;
|
|
745
|
+
}
|
|
746
|
+
|
|
747
|
+
state->sums[0] = sum_low, state->sums[1] = sum_high;
|
|
748
|
+
}
|
|
749
|
+
|
|
750
|
+
NK_INTERNAL void nk_dot_i4x16_finalize_serial(nk_dot_i4x16_state_serial_t const *state_a,
|
|
751
|
+
nk_dot_i4x16_state_serial_t const *state_b,
|
|
752
|
+
nk_dot_i4x16_state_serial_t const *state_c,
|
|
753
|
+
nk_dot_i4x16_state_serial_t const *state_d, nk_size_t total_dimensions,
|
|
754
|
+
nk_b128_vec_t *result) {
|
|
755
|
+
nk_unused_(total_dimensions);
|
|
756
|
+
result->i32s[0] = (nk_i32_t)(state_a->sums[0] + state_a->sums[1]);
|
|
757
|
+
result->i32s[1] = (nk_i32_t)(state_b->sums[0] + state_b->sums[1]);
|
|
758
|
+
result->i32s[2] = (nk_i32_t)(state_c->sums[0] + state_c->sums[1]);
|
|
759
|
+
result->i32s[3] = (nk_i32_t)(state_d->sums[0] + state_d->sums[1]);
|
|
760
|
+
}
|
|
761
|
+
|
|
762
|
+
#pragma endregion - Small Integers
|
|
763
|
+
|
|
764
|
+
#pragma region - Binary
|
|
765
|
+
|
|
766
|
+
/**
 *  Binary dot product: counts the bit positions set in both `a` and `b` and stores the count in `*result`.
 *  Processing is byte-wise over ceil(n_bits / NK_BITS_PER_BYTE) bytes. When `n_bits` is not a multiple
 *  of the byte width, every bit of the final byte is counted.
 *  NOTE(review): this presumably relies on callers keeping padding bits zero; confirm.
 */
NK_PUBLIC void nk_dot_u1_serial(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
    nk_u8_t const *a_bytes = (nk_u8_t const *)a;
    nk_u8_t const *b_bytes = (nk_u8_t const *)b;
    nk_size_t const n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
    nk_u32_t matches = 0;
    for (nk_size_t i = 0; i != n_bytes; ++i) matches += nk_u1x8_popcount_(a_bytes[i] & b_bytes[i]);
    *result = matches;
}
|
|
772
|
+
|
|
773
|
+
typedef struct nk_dot_u1x128_state_serial_t {
|
|
774
|
+
nk_u32_t dot_count;
|
|
775
|
+
} nk_dot_u1x128_state_serial_t;
|
|
776
|
+
|
|
777
|
+
NK_INTERNAL void nk_dot_u1x128_init_serial(nk_dot_u1x128_state_serial_t *state) { state->dot_count = 0; }
|
|
778
|
+
|
|
779
|
+
NK_INTERNAL void nk_dot_u1x128_update_serial(nk_dot_u1x128_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
780
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
781
|
+
nk_unused_(depth_offset);
|
|
782
|
+
nk_unused_(active_dimensions);
|
|
783
|
+
nk_u64_t and_low = a.u64s[0] & b.u64s[0];
|
|
784
|
+
nk_u64_t and_high = a.u64s[1] & b.u64s[1];
|
|
785
|
+
state->dot_count += (nk_u32_t)nk_u64_popcount_(and_low);
|
|
786
|
+
state->dot_count += (nk_u32_t)nk_u64_popcount_(and_high);
|
|
787
|
+
}
|
|
788
|
+
|
|
789
|
+
NK_INTERNAL void nk_dot_u1x128_finalize_serial(nk_dot_u1x128_state_serial_t const *state_a,
|
|
790
|
+
nk_dot_u1x128_state_serial_t const *state_b,
|
|
791
|
+
nk_dot_u1x128_state_serial_t const *state_c,
|
|
792
|
+
nk_dot_u1x128_state_serial_t const *state_d, nk_size_t total_dimensions,
|
|
793
|
+
nk_b128_vec_t *result) {
|
|
794
|
+
nk_unused_(total_dimensions);
|
|
795
|
+
result->u32s[0] = state_a->dot_count;
|
|
796
|
+
result->u32s[1] = state_b->dot_count;
|
|
797
|
+
result->u32s[2] = state_c->dot_count;
|
|
798
|
+
result->u32s[3] = state_d->dot_count;
|
|
799
|
+
}
|
|
800
|
+
|
|
801
|
+
#pragma endregion - Binary
|
|
802
|
+
|
|
803
|
+
/**
|
|
804
|
+
* Serial fallback sum helpers for progressive element-sum accumulation.
|
|
805
|
+
* Used by the compensated symmetric GEMM macro to piggyback sum computation
|
|
806
|
+
* on the depth loop's already-loaded vectors, avoiding a separate sum pass.
|
|
807
|
+
*/
|
|
808
|
+
|
|
809
|
+
#pragma region - Stateful Element Sum Helpers (for compensated GEMM)
|
|
810
|
+
|
|
811
|
+
/* i4x32: Haswell i4 (nk_b128_vec_t containing 32 nibbles in 16 bytes) */
|
|
812
|
+
typedef struct nk_sum_i4x32_state_serial_t {
|
|
813
|
+
nk_i64_t sum;
|
|
814
|
+
} nk_sum_i4x32_state_serial_t;
|
|
815
|
+
|
|
816
|
+
NK_INTERNAL void nk_sum_i4x32_init_serial(nk_sum_i4x32_state_serial_t *state) { state->sum = 0; }
|
|
817
|
+
|
|
818
|
+
NK_INTERNAL void nk_sum_i4x32_update_serial(nk_sum_i4x32_state_serial_t *state, nk_b128_vec_t v) {
|
|
819
|
+
nk_u8_t const *d = (nk_u8_t const *)&v;
|
|
820
|
+
for (int i = 0; i < 16; i++) {
|
|
821
|
+
nk_i8_t low = (nk_i8_t)((d[i] & 0x0F) ^ 0x08) - 8; /* sign-extend low nibble */
|
|
822
|
+
nk_i8_t high = (nk_i8_t)((d[i] >> 4) ^ 0x08) - 8; /* sign-extend high nibble */
|
|
823
|
+
state->sum += low + high;
|
|
824
|
+
}
|
|
825
|
+
}
|
|
826
|
+
|
|
827
|
+
NK_INTERNAL nk_i32_t nk_sum_i4x32_finalize_serial(nk_sum_i4x32_state_serial_t const *state, nk_size_t count) {
|
|
828
|
+
nk_unused_(count);
|
|
829
|
+
return (nk_i32_t)state->sum;
|
|
830
|
+
}
|
|
831
|
+
|
|
832
|
+
#pragma endregion - Stateful Element Sum Helpers
|
|
833
|
+
|
|
834
|
+
#if defined(__cplusplus)
|
|
835
|
+
} // extern "C"
|
|
836
|
+
#endif
|
|
837
|
+
|
|
838
|
+
#endif // NK_DOT_SERIAL_H
|