numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1258 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for WASM.
|
|
3
|
+
* @file include/numkong/dot/v128relaxed.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 31, 2026
|
|
6
|
+
*
|
|
7
|
+
* Requires Emscripten 3.1.27+ with `-msimd128 -mrelaxed-simd` flags.
|
|
8
|
+
*
|
|
9
|
+
* Key optimizations:
|
|
10
|
+
* - Uses relaxed FMA (f32x4.relaxed_madd, f64x2.relaxed_madd) for 2x throughput
|
|
11
|
+
* - Smart i8/u8 dot products using algebraic decomposition + correction terms
|
|
12
|
+
* - F64 upcasting variant for improved numerical precision (NEON-style)
|
|
13
|
+
*
|
|
14
|
+
* Smart i8 optimization:
|
|
15
|
+
* Decompose: b = b_7bit - 128 × signbit
|
|
16
|
+
* Therefore: a·b = a·b_7bit - 128 × sum(a[i] where b[i] < 0)
|
|
17
|
+
* Uses fast relaxed_dot_i8x16_i7x16 + SAD-like correction
|
|
18
|
+
*
|
|
19
|
+
* Smart u8 optimization:
|
|
20
|
+
* Bias both operands into i8 range: a' = a - 128, b' = b - 128 (XOR with 0x80)
* Therefore: a·b = a'·b' + 128 × (sum(a) + sum(b)) - 16384 × n
* where a'·b' reuses the i8 decomposition above, keeping relaxed_dot's internal i16 pair sums unsaturated
|
|
23
|
+
*/
|
|
24
|
+
|
|
25
|
+
#ifndef NK_DOT_V128RELAXED_H
|
|
26
|
+
#define NK_DOT_V128RELAXED_H
|
|
27
|
+
|
|
28
|
+
#if NK_TARGET_V128RELAXED
|
|
29
|
+
|
|
30
|
+
#include "numkong/types.h"
|
|
31
|
+
#include "numkong/reduce/v128relaxed.h"
|
|
32
|
+
#include "numkong/cast/serial.h"
|
|
33
|
+
#include "numkong/cast/v128relaxed.h"
|
|
34
|
+
|
|
35
|
+
#if defined(__cplusplus)
|
|
36
|
+
extern "C" {
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
#if defined(__clang__)
|
|
40
|
+
#pragma clang attribute push(__attribute__((target("relaxed-simd"))), apply_to = function)
|
|
41
|
+
#endif
|
|
42
|
+
|
|
43
|
+
/**
 * @brief Collapses a compensated f64x2 accumulator (sum + compensation) into a single f64.
 *
 * First folds the compensation into the sum lane-wise with Knuth's TwoSum, keeping the
 * rounding residue of that addition; then adds the two lanes with another TwoSum and
 * returns the total plus all residues gathered along the way.
 */
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x2_v128relaxed_(v128_t sum_f64x2, v128_t compensation_f64x2) {
    // Lane-wise TwoSum: folded = fl(sum + compensation), residue = exact error of that addition.
    v128_t folded_f64x2 = wasm_f64x2_add(sum_f64x2, compensation_f64x2);
    v128_t folded_delta_f64x2 = wasm_f64x2_sub(folded_f64x2, sum_f64x2);
    v128_t sum_residue_f64x2 = wasm_f64x2_sub(sum_f64x2, wasm_f64x2_sub(folded_f64x2, folded_delta_f64x2));
    v128_t compensation_residue_f64x2 = wasm_f64x2_sub(compensation_f64x2, folded_delta_f64x2);
    v128_t residue_f64x2 = wasm_f64x2_add(sum_residue_f64x2, compensation_residue_f64x2);

    // Scalar TwoSum across the two folded lanes.
    nk_f64_t const first = wasm_f64x2_extract_lane(folded_f64x2, 0);
    nk_f64_t const second = wasm_f64x2_extract_lane(folded_f64x2, 1);
    nk_f64_t const total = first + second;
    nk_f64_t const total_delta = total - first;
    nk_f64_t const total_residue = (first - (total - total_delta)) + (second - total_delta);

    // Residues are combined as (lane0 + lane1) + cross-lane, then added to the total once.
    nk_f64_t const first_residue = wasm_f64x2_extract_lane(residue_f64x2, 0);
    nk_f64_t const second_residue = wasm_f64x2_extract_lane(residue_f64x2, 1);
    nk_f64_t const lane_residues = first_residue + second_residue;
    return total + (lane_residues + total_residue);
}
|
|
58
|
+
|
|
59
|
+
/**
 * @brief f32 dot product accumulated in f64 for extra precision.
 *
 * Consumes two f32 values per step, promotes them to an f64x2 pair and accumulates with
 * `relaxed_madd`. A final partial step (fewer than two values left, including n == 0)
 * goes through the zero-padding partial loader, so the loop body always runs at least once.
 */
NK_PUBLIC void nk_dot_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    v128_t accumulator_f64x2 = wasm_f64x2_splat(0.0);
    nk_size_t remaining = n;
    nk_b64_vec_t a_pair, b_pair;

    do {
        if (remaining >= 2) {
            nk_load_b64_serial_(a, &a_pair);
            nk_load_b64_serial_(b, &b_pair);
            a += 2, b += 2, remaining -= 2;
        }
        else {
            nk_partial_load_b32x2_serial_(a, &a_pair, remaining);
            nk_partial_load_b32x2_serial_(b, &b_pair, remaining);
            remaining = 0;
        }
        // Place the two f32 values in the low half, then widen them to f64.
        v128_t a_f64x2 = wasm_f64x2_promote_low_f32x4(wasm_v128_load64_zero(&a_pair.u64));
        v128_t b_f64x2 = wasm_f64x2_promote_low_f32x4(wasm_v128_load64_zero(&b_pair.u64));
        accumulator_f64x2 = wasm_f64x2_relaxed_madd(a_f64x2, b_f64x2, accumulator_f64x2);
    } while (remaining);

    *result = nk_reduce_add_f64x2_v128relaxed_(accumulator_f64x2);
}
|
|
85
|
+
|
|
86
|
+
/**
 * @brief f16 dot product accumulated in f32.
 *
 * Reads four f16 values per step (a 64-bit chunk), widens them to f32x4 and accumulates
 * with `relaxed_madd`. The last step, with fewer than four values left (including n == 0),
 * uses the zero-padding partial loader.
 */
NK_PUBLIC void nk_dot_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    v128_t accumulator_f32x4 = wasm_f32x4_splat(0.0f);
    nk_size_t remaining = n;
    nk_b64_vec_t a_quad, b_quad;

    do {
        if (remaining >= 4) {
            nk_load_b64_serial_(a, &a_quad);
            nk_load_b64_serial_(b, &b_quad);
            a += 4, b += 4, remaining -= 4;
        }
        else {
            nk_partial_load_b16x4_serial_(a, &a_quad, remaining);
            nk_partial_load_b16x4_serial_(b, &b_quad, remaining);
            remaining = 0;
        }
        nk_b128_vec_t a_widened = nk_f16x4_to_f32x4_v128relaxed_(a_quad);
        nk_b128_vec_t b_widened = nk_f16x4_to_f32x4_v128relaxed_(b_quad);
        accumulator_f32x4 = wasm_f32x4_relaxed_madd(a_widened.v128, b_widened.v128, accumulator_f32x4);
    } while (remaining);

    *result = nk_reduce_add_f32x4_v128relaxed_(accumulator_f32x4);
}
|
|
110
|
+
|
|
111
|
+
/**
 * @brief bf16 dot product accumulated in f32.
 *
 * Reads four bf16 values per step (a 64-bit chunk), widens them to f32x4 and accumulates
 * with `relaxed_madd`. The last step, with fewer than four values left (including n == 0),
 * uses the zero-padding partial loader.
 */
NK_PUBLIC void nk_dot_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    v128_t accumulator_f32x4 = wasm_f32x4_splat(0.0f);
    nk_size_t remaining = n;
    nk_b64_vec_t a_quad, b_quad;

    do {
        if (remaining >= 4) {
            nk_load_b64_serial_(a, &a_quad);
            nk_load_b64_serial_(b, &b_quad);
            a += 4, b += 4, remaining -= 4;
        }
        else {
            nk_partial_load_b16x4_serial_(a, &a_quad, remaining);
            nk_partial_load_b16x4_serial_(b, &b_quad, remaining);
            remaining = 0;
        }
        nk_b128_vec_t a_widened = nk_bf16x4_to_f32x4_v128relaxed_(a_quad);
        nk_b128_vec_t b_widened = nk_bf16x4_to_f32x4_v128relaxed_(b_quad);
        accumulator_f32x4 = wasm_f32x4_relaxed_madd(a_widened.v128, b_widened.v128, accumulator_f32x4);
    } while (remaining);

    *result = nk_reduce_add_f32x4_v128relaxed_(accumulator_f32x4);
}
|
|
135
|
+
|
|
136
|
+
/**
 * @brief Compensated f64 dot product (TwoProduct + TwoSum per lane).
 *
 * Each product a[i]*b[i] is split into its rounded value and its rounding error with an FMA
 * (TwoProduct). Each accumulation step is split into its rounded sum and its error with
 * Knuth's TwoSum. Both error terms are collected in `compensation_f64x2` and folded in at the
 * end by `nk_dot_stable_sum_f64x2_v128relaxed_`.
 */
NK_PUBLIC void nk_dot_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    v128_t sum_f64x2 = wasm_f64x2_splat(0.0);
    v128_t compensation_f64x2 = wasm_f64x2_splat(0.0);
    nk_f64_t const *a_scalars = a, *b_scalars = b;
    nk_size_t count_scalars = n;
    nk_b128_vec_t a_vec, b_vec;

nk_dot_f64_v128relaxed_cycle:
    if (count_scalars < 2) {
        nk_partial_load_b64x2_serial_(a_scalars, &a_vec, count_scalars);
        nk_partial_load_b64x2_serial_(b_scalars, &b_vec, count_scalars);
        count_scalars = 0;
    }
    else {
        nk_load_b128_serial_(a_scalars, &a_vec);
        nk_load_b128_serial_(b_scalars, &b_vec);
        a_scalars += 2, b_scalars += 2, count_scalars -= 2;
    }

    // TwoProduct: product + product_error == a * b exactly, provided relaxed_madd is fused.
    // The error MUST be computed as fma(a, b, -product). Using fma(a, b, 0) would just round a*b
    // a second time, giving the same value as `product`, so the error would always be zero.
    // On targets where relaxed_madd is unfused, this degrades gracefully to a zero error term.
    v128_t product_f64x2 = wasm_f64x2_mul(a_vec.v128, b_vec.v128);
    v128_t product_error_f64x2 =
        wasm_f64x2_relaxed_madd(a_vec.v128, b_vec.v128, wasm_f64x2_neg(product_f64x2));

    // TwoSum: tentative_sum + sum_error == sum + product exactly.
    v128_t tentative_sum_f64x2 = wasm_f64x2_add(sum_f64x2, product_f64x2);
    v128_t virtual_addend_f64x2 = wasm_f64x2_sub(tentative_sum_f64x2, sum_f64x2);
    v128_t sum_error_f64x2 = wasm_f64x2_add(
        wasm_f64x2_sub(sum_f64x2, wasm_f64x2_sub(tentative_sum_f64x2, virtual_addend_f64x2)),
        wasm_f64x2_sub(product_f64x2, virtual_addend_f64x2));
    sum_f64x2 = tentative_sum_f64x2;
    compensation_f64x2 = wasm_f64x2_add(compensation_f64x2, wasm_f64x2_add(sum_error_f64x2, product_error_f64x2));
    if (count_scalars) goto nk_dot_f64_v128relaxed_cycle;

    *result = nk_dot_stable_sum_f64x2_v128relaxed_(sum_f64x2, compensation_f64x2);
}
|
|
168
|
+
|
|
169
|
+
/**
 * @brief Signed 8-bit dot product with a 32-bit (truncated) result.
 *
 * Every `b` byte is split into its low 7 bits and its sign: b = (b & 0x7F) - 128 * [b < 0].
 * The 7-bit part feeds `relaxed_dot_i8x16_i7x16`, whose second operand must lie in [0, 127].
 * The sign part becomes a correction built from the `a` bytes whose partner in `b` is negative.
 * Full 16-byte blocks are processed in windows of at most 127 blocks, so the i16 correction
 * lanes (at most 127 * 256 = 32512 in magnitude) cannot overflow before they are widened.
 * Leftover bytes are handled with scalar code.
 */
NK_PUBLIC void nk_dot_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result) {
    v128_t const zeros_i8x16 = wasm_i8x16_splat(0);
    v128_t const low7_i8x16 = wasm_i8x16_splat(0x7F);
    nk_size_t const block_count = n / 16;
    nk_size_t block_index = 0;
    nk_i64_t accumulated = 0;

    while (block_index < block_count) {
        nk_size_t const blocks_left = block_count - block_index;
        nk_size_t const window_end = block_index + (blocks_left < 127 ? blocks_left : 127);
        v128_t dot_i32x4 = wasm_i32x4_splat(0);
        v128_t negative_pairs_i16x8 = wasm_i16x8_splat(0);

        for (; block_index < window_end; ++block_index) {
            v128_t const a_i8x16 = wasm_v128_load(a + block_index * 16);
            v128_t const b_i8x16 = wasm_v128_load(b + block_index * 16);
            v128_t const b_is_negative_i8x16 = wasm_i8x16_lt(b_i8x16, zeros_i8x16);
            v128_t const b_low7_i8x16 = wasm_v128_and(b_i8x16, low7_i8x16);

            dot_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_i8x16, b_low7_i8x16, dot_i32x4);
            // Sum of the `a` bytes paired with a negative `b`, kept pairwise in i16 lanes.
            negative_pairs_i16x8 = wasm_i16x8_add(
                negative_pairs_i16x8, wasm_i16x8_extadd_pairwise_i8x16(wasm_v128_and(a_i8x16, b_is_negative_i8x16)));
        }

        // dot -= 128 * correction; the multiply by 128 is a left shift by 7 (same result mod 2^32).
        v128_t const negative_sum_i32x4 = wasm_i32x4_extadd_pairwise_i16x8(negative_pairs_i16x8);
        dot_i32x4 = wasm_i32x4_sub(dot_i32x4, wasm_i32x4_shl(negative_sum_i32x4, 7));
        accumulated += (nk_i32_t)nk_reduce_add_i32x4_v128relaxed_(dot_i32x4);
    }

    for (nk_size_t tail = block_count * 16; tail < n; ++tail) {
        accumulated += (nk_i32_t)a[tail] * (nk_i32_t)b[tail];
    }

    *result = (nk_i32_t)accumulated;
}
|
|
215
|
+
|
|
216
|
+
/**
 * @brief Unsigned 8-bit dot product with a 32-bit (truncated) result.
 *
 * Both operands are biased into i8 range: a' = a - 128, b' = b - 128 (XOR with 0x80).
 * This keeps relaxed_dot's internal i16 pair sums from saturating. a'·b' is computed with the
 * same 7-bit + sign-correction decomposition as the i8 kernel, and the unbiased result is
 * recovered from (a - 128)(b - 128) = ab - 128a - 128b + 16384, i.e.
 *     a·b = a'·b' + 128 * (sum(a) + sum(b)) - 16384 * n.
 */
NK_PUBLIC void nk_dot_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
    // Loop-invariant constants. -128 is the 0x80 byte, written without an implementation-defined char conversion.
    v128_t const bias_i8x16 = wasm_i8x16_splat(-128);
    v128_t const low7_mask_i8x16 = wasm_i8x16_splat(0x7F);
    v128_t const zero_i8x16 = wasm_i8x16_splat(0);
    nk_i64_t biased_sum_total = 0;
    nk_i64_t sum_a_total = 0;
    nk_i64_t sum_b_total = 0;
    nk_size_t i = 0;

    while (i + 16 <= n) {
        v128_t biased_dot_i32x4 = wasm_i32x4_splat(0);
        v128_t correction_i16x8 = wasm_i16x8_splat(0);
        v128_t sum_a_u16x8 = wasm_u16x8_splat(0);
        v128_t sum_b_u16x8 = wasm_u16x8_splat(0);

        // Overflow safety for a window of at most 127 iterations:
        // - each correction_i16x8 lane gains a pairwise sum of two i8 values, which lies in [-256, +254].
        //   Its magnitude therefore stays within 127 * 256 = 32512 <= 32768.
        // - each sum_a/sum_b lane gains at most 2 * 255 = 510 per iteration, reaching at most
        //   127 * 510 = 64770 <= 65535.
        // - relaxed_dot pairs i8 [-128, 127] with u7 [0, 127], so any internal i16 pair sum is at most
        //   2 * 128 * 127 = 32512 in magnitude.
        nk_size_t cycle = 0;
        for (; cycle < 127 && i + 16 <= n; ++cycle, i += 16) {
            v128_t a_u8x16 = wasm_v128_load(a + i);
            v128_t b_u8x16 = wasm_v128_load(b + i);
            // XOR with 0x80 maps u8 [0, 255] onto i8 [-128, 127], i.e. x - 128.
            v128_t a_i8x16 = wasm_v128_xor(a_u8x16, bias_i8x16);
            v128_t b_i8x16 = wasm_v128_xor(b_u8x16, bias_i8x16);
            // b' = (b' & 0x7F) - 128 * [b' < 0], as in the i8 kernel.
            v128_t b_7bit_u8x16 = wasm_v128_and(b_i8x16, low7_mask_i8x16);
            v128_t b_neg_mask_i8x16 = wasm_i8x16_lt(b_i8x16, zero_i8x16);

            biased_dot_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_i8x16, b_7bit_u8x16, biased_dot_i32x4);
            correction_i16x8 = wasm_i16x8_add(
                correction_i16x8, wasm_i16x8_extadd_pairwise_i8x16(wasm_v128_and(a_i8x16, b_neg_mask_i8x16)));
            sum_a_u16x8 = wasm_i16x8_add(sum_a_u16x8, wasm_u16x8_extadd_pairwise_u8x16(a_u8x16));
            sum_b_u16x8 = wasm_i16x8_add(sum_b_u16x8, wasm_u16x8_extadd_pairwise_u8x16(b_u8x16));
        }

        v128_t correction_i32x4 = wasm_i32x4_extadd_pairwise_i16x8(correction_i16x8);
        v128_t sum_a_u32x4 = wasm_u32x4_extadd_pairwise_u16x8(sum_a_u16x8);
        v128_t sum_b_u32x4 = wasm_u32x4_extadd_pairwise_u16x8(sum_b_u16x8);
        biased_sum_total += nk_reduce_add_i32x4_v128relaxed_(biased_dot_i32x4) -
                            128LL * nk_reduce_add_i32x4_v128relaxed_(correction_i32x4);
        sum_a_total += nk_reduce_add_u32x4_v128relaxed_(sum_a_u32x4);
        sum_b_total += nk_reduce_add_u32x4_v128relaxed_(sum_b_u32x4);
    }

    // Scalar tail, using the same biased formulation so the final un-biasing covers all n elements.
    for (; i < n; i++) {
        nk_i32_t a_biased = (nk_i32_t)a[i] - 128;
        nk_i32_t b_biased = (nk_i32_t)b[i] - 128;
        biased_sum_total += (nk_i64_t)a_biased * b_biased;
        sum_a_total += a[i];
        sum_b_total += b[i];
    }

    // Undo the bias: a·b = a'·b' + 128 * (sum(a) + sum(b)) - 16384 * n.
    biased_sum_total += 128LL * (sum_a_total + sum_b_total) - (nk_i64_t)n * 16384LL;
    *result = (nk_u32_t)biased_sum_total;
}
|
|
268
|
+
|
|
269
|
+
NK_PUBLIC void nk_dot_e2m3_v128relaxed(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
270
|
+
nk_f32_t *result) {
|
|
271
|
+
// Integer dot product for e2m3 using relaxed SIMD: wasm_i32x4_relaxed_dot_i8x16_i7x16_add.
|
|
272
|
+
// Every e2m3 value × 16 is an exact integer in [-120, +120].
|
|
273
|
+
// The relaxed dot takes i8 × u7 (first signed, second unsigned [0,127]). Our magnitudes [0,120] fit.
|
|
274
|
+
// Result = i32_dot / 256.0f (exact, no rounding error).
|
|
275
|
+
//
|
|
276
|
+
// 32-entry LUT split into two 16-entry halves for wasm_i8x16_relaxed_swizzle (indexes 0-15).
|
|
277
|
+
v128_t lut_lower_u8x16 = wasm_i8x16_const(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
|
|
278
|
+
v128_t lut_upper_u8x16 = wasm_i8x16_const(32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120);
|
|
279
|
+
v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
|
|
280
|
+
v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
|
|
281
|
+
v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
|
|
282
|
+
v128_t sign_mask_u8x16 = wasm_u8x16_splat(0x20);
|
|
283
|
+
v128_t sum_i32x4 = wasm_i32x4_splat(0);
|
|
284
|
+
v128_t a_e2m3_u8x16, b_e2m3_u8x16;
|
|
285
|
+
|
|
286
|
+
nk_dot_e2m3_v128relaxed_cycle:
|
|
287
|
+
if (count_scalars < 16) {
|
|
288
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
289
|
+
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
290
|
+
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
291
|
+
a_e2m3_u8x16 = a_vec.v128;
|
|
292
|
+
b_e2m3_u8x16 = b_vec.v128;
|
|
293
|
+
count_scalars = 0;
|
|
294
|
+
}
|
|
295
|
+
else {
|
|
296
|
+
a_e2m3_u8x16 = wasm_v128_load(a_scalars);
|
|
297
|
+
b_e2m3_u8x16 = wasm_v128_load(b_scalars);
|
|
298
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
// Extract 5-bit magnitude indices
|
|
302
|
+
v128_t a_magnitude_u8x16 = wasm_v128_and(a_e2m3_u8x16, magnitude_mask_u8x16);
|
|
303
|
+
v128_t b_magnitude_u8x16 = wasm_v128_and(b_e2m3_u8x16, magnitude_mask_u8x16);
|
|
304
|
+
|
|
305
|
+
// Dual swizzle + bitselect for 32-entry LUT (a)
|
|
306
|
+
v128_t a_shuffle_index_u8x16 = wasm_v128_and(a_magnitude_u8x16, nibble_mask_u8x16);
|
|
307
|
+
v128_t a_lower_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, a_shuffle_index_u8x16);
|
|
308
|
+
v128_t a_upper_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, a_shuffle_index_u8x16);
|
|
309
|
+
v128_t a_upper_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
|
|
310
|
+
v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_upper_u8x16, a_lower_u8x16, a_upper_select_u8x16);
|
|
311
|
+
|
|
312
|
+
// Dual swizzle + bitselect for 32-entry LUT (b)
|
|
313
|
+
v128_t b_shuffle_index_u8x16 = wasm_v128_and(b_magnitude_u8x16, nibble_mask_u8x16);
|
|
314
|
+
v128_t b_lower_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, b_shuffle_index_u8x16);
|
|
315
|
+
v128_t b_upper_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, b_shuffle_index_u8x16);
|
|
316
|
+
v128_t b_upper_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
|
|
317
|
+
v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_upper_u8x16, b_lower_u8x16, b_upper_select_u8x16);
|
|
318
|
+
|
|
319
|
+
// Combined sign: (a ^ b) & 0x20 — nonzero means negative product
|
|
320
|
+
// Apply sign to a (relaxed_dot wants i8 × u7: a_signed, b_unsigned)
|
|
321
|
+
v128_t sign_combined_u8x16 = wasm_v128_and(wasm_v128_xor(a_e2m3_u8x16, b_e2m3_u8x16), sign_mask_u8x16);
|
|
322
|
+
v128_t negate_mask_u8x16 = wasm_i8x16_eq(sign_combined_u8x16, sign_mask_u8x16);
|
|
323
|
+
v128_t a_negated_u8x16 = wasm_i8x16_neg(a_unsigned_u8x16);
|
|
324
|
+
v128_t a_signed_i8x16 = wasm_i8x16_relaxed_laneselect(a_negated_u8x16, a_unsigned_u8x16, negate_mask_u8x16);
|
|
325
|
+
|
|
326
|
+
// relaxed_dot: a_signed[i8] × b_unsigned[u7] → i32 accumulate
|
|
327
|
+
sum_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_unsigned_u8x16, sum_i32x4);
|
|
328
|
+
|
|
329
|
+
if (count_scalars) goto nk_dot_e2m3_v128relaxed_cycle;
|
|
330
|
+
*result = (nk_f32_t)nk_reduce_add_i32x4_v128relaxed_(sum_i32x4) / 256.0f;
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
NK_PUBLIC void nk_dot_e3m2_v128relaxed(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
|
|
334
|
+
nk_f32_t *result) {
|
|
335
|
+
// Integer dot product for e3m2 using i16 arithmetic with widening multiply.
|
|
336
|
+
// Every e3m2 value × 16 is an exact integer, but magnitudes reach 448, requiring i16.
|
|
337
|
+
// Result = i32_dot / 256.0f (exact, no rounding error).
|
|
338
|
+
//
|
|
339
|
+
// 32-entry magnitude LUT split into low-byte halves for dual swizzle lookup.
|
|
340
|
+
// High byte is 0 for indices 0-27 and 1 for indices 28-31, so a simple comparison
|
|
341
|
+
// replaces the high-byte LUT entirely.
|
|
342
|
+
//
|
|
343
|
+
// Low-byte LUT entries (magnitude[i] & 0xFF):
|
|
344
|
+
// [0,1,2,3,4,5,6,7,8,10,12,14,16,20,24,28] lower half
|
|
345
|
+
// [32,40,48,56,64,80,96,112,128,160,192,224,0,64,128,192] upper half
|
|
346
|
+
v128_t lut_lo_lower_u8x16 = wasm_i8x16_const(0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28);
|
|
347
|
+
v128_t lut_lo_upper_u8x16 = wasm_u8x16_const(32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 0, 64, 128, 192);
|
|
348
|
+
v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
|
|
349
|
+
v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
|
|
350
|
+
v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
|
|
351
|
+
v128_t hi_threshold_u8x16 = wasm_u8x16_splat(28);
|
|
352
|
+
v128_t sign_mask_u8x16 = wasm_u8x16_splat(0x20);
|
|
353
|
+
v128_t sum_i32x4 = wasm_i32x4_splat(0);
|
|
354
|
+
v128_t a_e3m2_u8x16, b_e3m2_u8x16;
|
|
355
|
+
|
|
356
|
+
nk_dot_e3m2_v128relaxed_cycle:
|
|
357
|
+
if (count_scalars < 16) {
|
|
358
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
359
|
+
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
360
|
+
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
361
|
+
a_e3m2_u8x16 = a_vec.v128;
|
|
362
|
+
b_e3m2_u8x16 = b_vec.v128;
|
|
363
|
+
count_scalars = 0;
|
|
364
|
+
}
|
|
365
|
+
else {
|
|
366
|
+
a_e3m2_u8x16 = wasm_v128_load(a_scalars);
|
|
367
|
+
b_e3m2_u8x16 = wasm_v128_load(b_scalars);
|
|
368
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
// Extract 5-bit magnitude indices
|
|
372
|
+
v128_t a_magnitude_u8x16 = wasm_v128_and(a_e3m2_u8x16, magnitude_mask_u8x16);
|
|
373
|
+
v128_t b_magnitude_u8x16 = wasm_v128_and(b_e3m2_u8x16, magnitude_mask_u8x16);
|
|
374
|
+
|
|
375
|
+
// Dual swizzle + bitselect for 32-entry low-byte LUT (a)
|
|
376
|
+
v128_t a_shuffle_index_u8x16 = wasm_v128_and(a_magnitude_u8x16, nibble_mask_u8x16);
|
|
377
|
+
v128_t a_lower_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lo_lower_u8x16, a_shuffle_index_u8x16);
|
|
378
|
+
v128_t a_upper_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lo_upper_u8x16, a_shuffle_index_u8x16);
|
|
379
|
+
v128_t a_upper_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
|
|
380
|
+
v128_t a_lo_bytes_u8x16 = wasm_i8x16_relaxed_laneselect(a_upper_u8x16, a_lower_u8x16, a_upper_select_u8x16);
|
|
381
|
+
|
|
382
|
+
// High byte is 1 iff magnitude index >= 28 (values 256, 320, 384, 448), else 0
|
|
383
|
+
v128_t a_hi_bytes_u8x16 = wasm_v128_and(wasm_u8x16_ge(a_magnitude_u8x16, hi_threshold_u8x16), wasm_u8x16_splat(1));
|
|
384
|
+
|
|
385
|
+
// Dual swizzle + bitselect for 32-entry low-byte LUT (b)
|
|
386
|
+
v128_t b_shuffle_index_u8x16 = wasm_v128_and(b_magnitude_u8x16, nibble_mask_u8x16);
|
|
387
|
+
v128_t b_lower_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lo_lower_u8x16, b_shuffle_index_u8x16);
|
|
388
|
+
v128_t b_upper_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lo_upper_u8x16, b_shuffle_index_u8x16);
|
|
389
|
+
v128_t b_upper_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
|
|
390
|
+
v128_t b_lo_bytes_u8x16 = wasm_i8x16_relaxed_laneselect(b_upper_u8x16, b_lower_u8x16, b_upper_select_u8x16);
|
|
391
|
+
|
|
392
|
+
// High byte is 1 iff magnitude index >= 28
|
|
393
|
+
v128_t b_hi_bytes_u8x16 = wasm_v128_and(wasm_u8x16_ge(b_magnitude_u8x16, hi_threshold_u8x16), wasm_u8x16_splat(1));
|
|
394
|
+
|
|
395
|
+
// Combine low and high bytes into i16 via byte interleave shuffle (little-endian: low byte first)
|
|
396
|
+
v128_t a_unsigned_low_i16x8 = wasm_i8x16_shuffle(a_lo_bytes_u8x16, a_hi_bytes_u8x16, 0, 16, 1, 17, 2, 18, 3, 19, 4,
|
|
397
|
+
20, 5, 21, 6, 22, 7, 23);
|
|
398
|
+
v128_t a_unsigned_high_i16x8 = wasm_i8x16_shuffle(a_lo_bytes_u8x16, a_hi_bytes_u8x16, 8, 24, 9, 25, 10, 26, 11, 27,
|
|
399
|
+
12, 28, 13, 29, 14, 30, 15, 31);
|
|
400
|
+
v128_t b_unsigned_low_i16x8 = wasm_i8x16_shuffle(b_lo_bytes_u8x16, b_hi_bytes_u8x16, 0, 16, 1, 17, 2, 18, 3, 19, 4,
|
|
401
|
+
20, 5, 21, 6, 22, 7, 23);
|
|
402
|
+
v128_t b_unsigned_high_i16x8 = wasm_i8x16_shuffle(b_lo_bytes_u8x16, b_hi_bytes_u8x16, 8, 24, 9, 25, 10, 26, 11, 27,
|
|
403
|
+
12, 28, 13, 29, 14, 30, 15, 31);
|
|
404
|
+
|
|
405
|
+
// Combined sign: XOR sign bits, negate only b (saves ~15 ops vs independent negation)
|
|
406
|
+
v128_t sign_combined_u8x16 = wasm_v128_and(wasm_v128_xor(a_e3m2_u8x16, b_e3m2_u8x16), sign_mask_u8x16);
|
|
407
|
+
v128_t negate_mask_u8x16 = wasm_i8x16_eq(sign_combined_u8x16, sign_mask_u8x16);
|
|
408
|
+
v128_t negate_low_i16x8 = wasm_i8x16_shuffle(negate_mask_u8x16, negate_mask_u8x16, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5,
|
|
409
|
+
5, 6, 6, 7, 7);
|
|
410
|
+
v128_t negate_high_i16x8 = wasm_i8x16_shuffle(negate_mask_u8x16, negate_mask_u8x16, 8, 8, 9, 9, 10, 10, 11, 11, 12,
|
|
411
|
+
12, 13, 13, 14, 14, 15, 15);
|
|
412
|
+
b_unsigned_low_i16x8 = wasm_i16x8_relaxed_laneselect(wasm_i16x8_neg(b_unsigned_low_i16x8), b_unsigned_low_i16x8,
|
|
413
|
+
negate_low_i16x8);
|
|
414
|
+
b_unsigned_high_i16x8 = wasm_i16x8_relaxed_laneselect(wasm_i16x8_neg(b_unsigned_high_i16x8), b_unsigned_high_i16x8,
|
|
415
|
+
negate_high_i16x8);
|
|
416
|
+
|
|
417
|
+
// Widening multiply: i16×i16 → i32, accumulate (a is unsigned magnitude, b has combined sign)
|
|
418
|
+
sum_i32x4 = wasm_i32x4_add(sum_i32x4, wasm_i32x4_extmul_low_i16x8(a_unsigned_low_i16x8, b_unsigned_low_i16x8));
|
|
419
|
+
sum_i32x4 = wasm_i32x4_add(sum_i32x4, wasm_i32x4_extmul_high_i16x8(a_unsigned_low_i16x8, b_unsigned_low_i16x8));
|
|
420
|
+
sum_i32x4 = wasm_i32x4_add(sum_i32x4, wasm_i32x4_extmul_low_i16x8(a_unsigned_high_i16x8, b_unsigned_high_i16x8));
|
|
421
|
+
sum_i32x4 = wasm_i32x4_add(sum_i32x4, wasm_i32x4_extmul_high_i16x8(a_unsigned_high_i16x8, b_unsigned_high_i16x8));
|
|
422
|
+
|
|
423
|
+
if (count_scalars) goto nk_dot_e3m2_v128relaxed_cycle;
|
|
424
|
+
*result = (nk_f32_t)nk_reduce_add_i32x4_v128relaxed_(sum_i32x4) / 256.0f;
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
NK_PUBLIC void nk_dot_u1_v128relaxed(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
    // Binary dot product: the number of bit positions set in both inputs, popcount(a AND b),
    // evaluated over ceil(n_bits / 8) whole bytes.
    nk_u8_t const *a_bytes = (nk_u8_t const *)a;
    nk_u8_t const *b_bytes = (nk_u8_t const *)b;
    nk_size_t const n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
    nk_u32_t matches = 0;
    nk_size_t offset = 0;

    // A per-byte popcount is at most 8, so up to 31 of them fit in a u8 lane (31 x 8 = 248 <= 255).
    // Each window therefore sums up to 31 blocks of 16 bytes before widening and reducing once.
    while (offset + 16 <= n_bytes) {
        v128_t window_counts_u8x16 = wasm_i8x16_splat(0);
        for (nk_size_t step = 0; step != 31 && offset + 16 <= n_bytes; ++step) {
            v128_t shared_u8x16 = wasm_v128_and(wasm_v128_load(a_bytes + offset), wasm_v128_load(b_bytes + offset));
            window_counts_u8x16 = wasm_i8x16_add(window_counts_u8x16, wasm_i8x16_popcnt(shared_u8x16));
            offset += 16;
        }
        matches += nk_reduce_add_u8x16_v128relaxed_(window_counts_u8x16);
    }

    // Scalar tail for the remaining (fewer than 16) bytes.
    for (; offset < n_bytes; ++offset) {
        matches += nk_u1x8_popcount_((nk_u8_t)(a_bytes[offset] & b_bytes[offset]));
    }

    *result = matches;
}
|
|
467
|
+
|
|
468
|
+
/**
|
|
469
|
+
* Stateful GEMM kernels for batched dot products (4-way parallel accumulation).
|
|
470
|
+
* Used by nk_define_cross_packed_ / nk_define_cross_compensated_packed_ macros.
|
|
471
|
+
*/
|
|
472
|
+
|
|
473
|
+
typedef struct nk_dot_through_f32x4_state_v128relaxed_t_ {
|
|
474
|
+
v128_t sum_f32x4;
|
|
475
|
+
} nk_dot_through_f32x4_state_v128relaxed_t_;
|
|
476
|
+
|
|
477
|
+
NK_INTERNAL void nk_dot_through_f32x4_init_v128relaxed_(nk_dot_through_f32x4_state_v128relaxed_t_ *state) {
|
|
478
|
+
state->sum_f32x4 = wasm_f32x4_splat(0.0f);
|
|
479
|
+
}
|
|
480
|
+
|
|
481
|
+
NK_INTERNAL void nk_dot_through_f32x4_update_v128relaxed_(nk_dot_through_f32x4_state_v128relaxed_t_ *state,
|
|
482
|
+
nk_b128_vec_t a, nk_b128_vec_t b, nk_size_t depth_offset,
|
|
483
|
+
nk_size_t active_dimensions) {
|
|
484
|
+
nk_unused_(depth_offset);
|
|
485
|
+
nk_unused_(active_dimensions);
|
|
486
|
+
state->sum_f32x4 = wasm_f32x4_relaxed_madd(a.v128, b.v128, state->sum_f32x4);
|
|
487
|
+
}
|
|
488
|
+
|
|
489
|
+
NK_INTERNAL void nk_dot_through_f32x4_finalize_v128relaxed_( //
|
|
490
|
+
nk_dot_through_f32x4_state_v128relaxed_t_ const *state_a, nk_dot_through_f32x4_state_v128relaxed_t_ const *state_b,
|
|
491
|
+
nk_dot_through_f32x4_state_v128relaxed_t_ const *state_c, nk_dot_through_f32x4_state_v128relaxed_t_ const *state_d,
|
|
492
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
493
|
+
nk_unused_(total_dimensions);
|
|
494
|
+
result->f32s[0] = nk_reduce_add_f32x4_v128relaxed_(state_a->sum_f32x4);
|
|
495
|
+
result->f32s[1] = nk_reduce_add_f32x4_v128relaxed_(state_b->sum_f32x4);
|
|
496
|
+
result->f32s[2] = nk_reduce_add_f32x4_v128relaxed_(state_c->sum_f32x4);
|
|
497
|
+
result->f32s[3] = nk_reduce_add_f32x4_v128relaxed_(state_d->sum_f32x4);
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
typedef struct nk_dot_f32x2_state_v128relaxed_t {
|
|
501
|
+
v128_t sum_f64x2;
|
|
502
|
+
} nk_dot_f32x2_state_v128relaxed_t;
|
|
503
|
+
|
|
504
|
+
NK_INTERNAL void nk_dot_f32x2_init_v128relaxed(nk_dot_f32x2_state_v128relaxed_t *state) {
|
|
505
|
+
state->sum_f64x2 = wasm_f64x2_splat(0.0);
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
NK_INTERNAL void nk_dot_f32x2_update_v128relaxed(nk_dot_f32x2_state_v128relaxed_t *state, nk_b64_vec_t a,
|
|
509
|
+
nk_b64_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
510
|
+
nk_unused_(depth_offset);
|
|
511
|
+
nk_unused_(active_dimensions);
|
|
512
|
+
v128_t a_f32x2 = wasm_v128_load64_zero(&a.u64);
|
|
513
|
+
v128_t b_f32x2 = wasm_v128_load64_zero(&b.u64);
|
|
514
|
+
v128_t a_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
|
|
515
|
+
v128_t b_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
|
|
516
|
+
state->sum_f64x2 = wasm_f64x2_relaxed_madd(a_f64x2, b_f64x2, state->sum_f64x2);
|
|
517
|
+
}
|
|
518
|
+
|
|
519
|
+
NK_INTERNAL void nk_dot_f32x2_finalize_v128relaxed( //
|
|
520
|
+
nk_dot_f32x2_state_v128relaxed_t const *state_a, nk_dot_f32x2_state_v128relaxed_t const *state_b, //
|
|
521
|
+
nk_dot_f32x2_state_v128relaxed_t const *state_c, nk_dot_f32x2_state_v128relaxed_t const *state_d, //
|
|
522
|
+
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
523
|
+
nk_unused_(total_dimensions);
|
|
524
|
+
result->f64s[0] = nk_reduce_add_f64x2_v128relaxed_(state_a->sum_f64x2);
|
|
525
|
+
result->f64s[1] = nk_reduce_add_f64x2_v128relaxed_(state_b->sum_f64x2);
|
|
526
|
+
result->f64s[2] = nk_reduce_add_f64x2_v128relaxed_(state_c->sum_f64x2);
|
|
527
|
+
result->f64s[3] = nk_reduce_add_f64x2_v128relaxed_(state_d->sum_f64x2);
|
|
528
|
+
}
|
|
529
|
+
|
|
530
|
+
typedef struct nk_dot_f64x2_state_v128relaxed_t {
|
|
531
|
+
v128_t sum_f64x2;
|
|
532
|
+
v128_t compensation_f64x2;
|
|
533
|
+
} nk_dot_f64x2_state_v128relaxed_t;
|
|
534
|
+
|
|
535
|
+
NK_INTERNAL void nk_dot_f64x2_init_v128relaxed(nk_dot_f64x2_state_v128relaxed_t *state) {
|
|
536
|
+
state->sum_f64x2 = wasm_f64x2_splat(0.0);
|
|
537
|
+
state->compensation_f64x2 = wasm_f64x2_splat(0.0);
|
|
538
|
+
}
|
|
539
|
+
|
|
540
|
+
/**
 *  Compensated (Dot2-style) update: adds a[i] * b[i] to the running sum and tracks both the product's and
 *  the addition's rounding errors in `compensation_f64x2`. `depth_offset` and `active_dimensions` are ignored.
 */
NK_INTERNAL void nk_dot_f64x2_update_v128relaxed(nk_dot_f64x2_state_v128relaxed_t *state, nk_b128_vec_t a,
                                                 nk_b128_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);

    // TwoProd: product + product_error == a * b exactly when relaxed_madd is fused.
    // The error must be fma(a, b, -product). fma(a, b, 0) merely re-rounds a * b to `product`, which made the
    // error identically zero. If relaxed_madd lowers to an unfused multiply-add, fma(a, b, -product) evaluates
    // to round(product - product) = 0, degrading gracefully to compensating only the summation.
    v128_t product_f64x2 = wasm_f64x2_mul(a.v128, b.v128);
    v128_t product_error_f64x2 = wasm_f64x2_relaxed_madd(a.v128, b.v128, wasm_f64x2_neg(product_f64x2));

    // TwoSum (Knuth, branch-free): tentative_sum + sum_error == sum + product exactly.
    v128_t tentative_sum_f64x2 = wasm_f64x2_add(state->sum_f64x2, product_f64x2);
    v128_t virtual_addend_f64x2 = wasm_f64x2_sub(tentative_sum_f64x2, state->sum_f64x2);
    v128_t sum_error_f64x2 = wasm_f64x2_add(
        wasm_f64x2_sub(state->sum_f64x2, wasm_f64x2_sub(tentative_sum_f64x2, virtual_addend_f64x2)),
        wasm_f64x2_sub(product_f64x2, virtual_addend_f64x2));

    state->sum_f64x2 = tentative_sum_f64x2;
    state->compensation_f64x2 = wasm_f64x2_add(state->compensation_f64x2,
                                               wasm_f64x2_add(sum_error_f64x2, product_error_f64x2));
}
|
|
556
|
+
|
|
557
|
+
NK_INTERNAL void nk_dot_f64x2_finalize_v128relaxed( //
|
|
558
|
+
nk_dot_f64x2_state_v128relaxed_t const *state_a, nk_dot_f64x2_state_v128relaxed_t const *state_b, //
|
|
559
|
+
nk_dot_f64x2_state_v128relaxed_t const *state_c, nk_dot_f64x2_state_v128relaxed_t const *state_d, //
|
|
560
|
+
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
561
|
+
nk_unused_(total_dimensions);
|
|
562
|
+
result->f64s[0] = nk_dot_stable_sum_f64x2_v128relaxed_(state_a->sum_f64x2, state_a->compensation_f64x2);
|
|
563
|
+
result->f64s[1] = nk_dot_stable_sum_f64x2_v128relaxed_(state_b->sum_f64x2, state_b->compensation_f64x2);
|
|
564
|
+
result->f64s[2] = nk_dot_stable_sum_f64x2_v128relaxed_(state_c->sum_f64x2, state_c->compensation_f64x2);
|
|
565
|
+
result->f64s[3] = nk_dot_stable_sum_f64x2_v128relaxed_(state_d->sum_f64x2, state_d->compensation_f64x2);
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
NK_INTERNAL void nk_load_bf16x4_to_f32x4_v128relaxed_(void const *src, nk_b128_vec_t *dst) {
|
|
569
|
+
nk_b64_vec_t raw;
|
|
570
|
+
nk_copy_bytes_(&raw, src, 8);
|
|
571
|
+
*dst = nk_bf16x4_to_f32x4_v128relaxed_(raw);
|
|
572
|
+
}
|
|
573
|
+
|
|
574
|
+
NK_INTERNAL void nk_partial_load_bf16x4_to_f32x4_v128relaxed_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
575
|
+
nk_b64_vec_t raw = {0};
|
|
576
|
+
nk_copy_bytes_(&raw, src, n * sizeof(nk_bf16_t));
|
|
577
|
+
*dst = nk_bf16x4_to_f32x4_v128relaxed_(raw);
|
|
578
|
+
}
|
|
579
|
+
|
|
580
|
+
NK_INTERNAL void nk_load_f16x4_to_f32x4_v128relaxed_(void const *src, nk_b128_vec_t *dst) {
|
|
581
|
+
nk_b64_vec_t raw;
|
|
582
|
+
nk_copy_bytes_(&raw, src, 8);
|
|
583
|
+
*dst = nk_f16x4_to_f32x4_v128relaxed_(raw);
|
|
584
|
+
}
|
|
585
|
+
|
|
586
|
+
NK_INTERNAL void nk_partial_load_f16x4_to_f32x4_v128relaxed_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
587
|
+
nk_b64_vec_t raw = {0};
|
|
588
|
+
nk_copy_bytes_(&raw, src, n * sizeof(nk_f16_t));
|
|
589
|
+
*dst = nk_f16x4_to_f32x4_v128relaxed_(raw);
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
typedef struct nk_dot_i8x16_state_v128relaxed_t {
|
|
593
|
+
v128_t product_sum_i32x4; // relaxed_dot accumulator
|
|
594
|
+
v128_t negative_sum_a_i32x4; // Σ(a[i] where b[i]<0), widened to i32
|
|
595
|
+
} nk_dot_i8x16_state_v128relaxed_t;
|
|
596
|
+
|
|
597
|
+
NK_INTERNAL void nk_dot_i8x16_init_v128relaxed(nk_dot_i8x16_state_v128relaxed_t *state) {
|
|
598
|
+
state->product_sum_i32x4 = wasm_i32x4_splat(0);
|
|
599
|
+
state->negative_sum_a_i32x4 = wasm_i32x4_splat(0);
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
NK_INTERNAL void nk_dot_i8x16_update_v128relaxed(nk_dot_i8x16_state_v128relaxed_t *state, nk_b128_vec_t a,
|
|
603
|
+
nk_b128_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
604
|
+
nk_unused_(depth_offset);
|
|
605
|
+
nk_unused_(active_dimensions);
|
|
606
|
+
// Bit-split: b = b_lo + (-128)·b_hi where b_lo = b & 0x7F ∈ [0,127], b_hi = b >> 7 ∈ {0,1}
|
|
607
|
+
// So a·b = a·b_lo − 128·a·b_hi, both operands fit i7 for relaxed_dot
|
|
608
|
+
v128_t b_lo_u8x16 = wasm_v128_and(b.v128, wasm_i8x16_splat(0x7F));
|
|
609
|
+
v128_t b_hi_u8x16 = wasm_u8x16_shr(b.v128, 7);
|
|
610
|
+
state->product_sum_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a.v128, b_lo_u8x16, state->product_sum_i32x4);
|
|
611
|
+
state->negative_sum_a_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a.v128, b_hi_u8x16,
|
|
612
|
+
state->negative_sum_a_i32x4);
|
|
613
|
+
}
|
|
614
|
+
|
|
615
|
+
NK_INTERNAL void nk_dot_i8x16_finalize_v128relaxed( //
|
|
616
|
+
nk_dot_i8x16_state_v128relaxed_t const *state_a, nk_dot_i8x16_state_v128relaxed_t const *state_b, //
|
|
617
|
+
nk_dot_i8x16_state_v128relaxed_t const *state_c, nk_dot_i8x16_state_v128relaxed_t const *state_d, //
|
|
618
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
619
|
+
nk_unused_(total_dimensions);
|
|
620
|
+
// For each state: result = reduce(product_sum) − 128 × reduce(negative_sum_a)
|
|
621
|
+
result->i32s[0] = nk_reduce_add_i32x4_v128relaxed_(state_a->product_sum_i32x4) -
|
|
622
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_a->negative_sum_a_i32x4);
|
|
623
|
+
result->i32s[1] = nk_reduce_add_i32x4_v128relaxed_(state_b->product_sum_i32x4) -
|
|
624
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_b->negative_sum_a_i32x4);
|
|
625
|
+
result->i32s[2] = nk_reduce_add_i32x4_v128relaxed_(state_c->product_sum_i32x4) -
|
|
626
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_c->negative_sum_a_i32x4);
|
|
627
|
+
result->i32s[3] = nk_reduce_add_i32x4_v128relaxed_(state_d->product_sum_i32x4) -
|
|
628
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_d->negative_sum_a_i32x4);
|
|
629
|
+
}
|
|
630
|
+
|
|
631
|
+
typedef struct nk_dot_u8x16_state_v128relaxed_t {
|
|
632
|
+
v128_t product_lo_i32x4; // relaxed_dot(a_signed, b_lo) accumulator
|
|
633
|
+
v128_t product_hi_i32x4; // relaxed_dot(a_signed, b_hi) accumulator
|
|
634
|
+
} nk_dot_u8x16_state_v128relaxed_t;
|
|
635
|
+
|
|
636
|
+
NK_INTERNAL void nk_dot_u8x16_init_v128relaxed(nk_dot_u8x16_state_v128relaxed_t *state) {
|
|
637
|
+
state->product_lo_i32x4 = wasm_i32x4_splat(0);
|
|
638
|
+
state->product_hi_i32x4 = wasm_i32x4_splat(0);
|
|
639
|
+
}
|
|
640
|
+
|
|
641
|
+
NK_INTERNAL void nk_dot_u8x16_update_v128relaxed(nk_dot_u8x16_state_v128relaxed_t *state, nk_b128_vec_t a,
|
|
642
|
+
nk_b128_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
643
|
+
nk_unused_(depth_offset);
|
|
644
|
+
nk_unused_(active_dimensions);
|
|
645
|
+
// Bit-split b: b = b_lo + 128·b_hi, with a_signed = a ^ 0x80 = a - 128 (reinterpret u8 as i8)
|
|
646
|
+
// Σ a·b = Σ(a_signed+128)·(b_lo+128·b_hi) = relaxed_dot(a_signed,b_lo) + 128·relaxed_dot(a_signed,b_hi) + 128·Σb
|
|
647
|
+
v128_t a_signed_i8x16 = wasm_v128_xor(a.v128, wasm_i8x16_splat((signed char)0x80));
|
|
648
|
+
v128_t b_lo_u8x16 = wasm_v128_and(b.v128, wasm_i8x16_splat(0x7F));
|
|
649
|
+
v128_t b_hi_u8x16 = wasm_u8x16_shr(b.v128, 7);
|
|
650
|
+
state->product_lo_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_lo_u8x16,
|
|
651
|
+
state->product_lo_i32x4);
|
|
652
|
+
state->product_hi_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_hi_u8x16,
|
|
653
|
+
state->product_hi_i32x4);
|
|
654
|
+
}
|
|
655
|
+
|
|
656
|
+
NK_INTERNAL void nk_dot_u8x16_finalize_v128relaxed( //
|
|
657
|
+
nk_dot_u8x16_state_v128relaxed_t const *state_a, nk_dot_u8x16_state_v128relaxed_t const *state_b, //
|
|
658
|
+
nk_dot_u8x16_state_v128relaxed_t const *state_c, nk_dot_u8x16_state_v128relaxed_t const *state_d, //
|
|
659
|
+
nk_size_t total_dimensions, nk_u32_t a_sum, nk_b128_vec_t b_sums, nk_b128_vec_t *result) {
|
|
660
|
+
nk_unused_(a_sum);
|
|
661
|
+
// Σ a·b = reduce(lo) + 128·reduce(hi) + 128·Σb
|
|
662
|
+
result->u32s[0] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_a->product_lo_i32x4) +
|
|
663
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_a->product_hi_i32x4) +
|
|
664
|
+
128 * (nk_i32_t)b_sums.u32s[0]);
|
|
665
|
+
result->u32s[1] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_b->product_lo_i32x4) +
|
|
666
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_b->product_hi_i32x4) +
|
|
667
|
+
128 * (nk_i32_t)b_sums.u32s[1]);
|
|
668
|
+
result->u32s[2] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_c->product_lo_i32x4) +
|
|
669
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_c->product_hi_i32x4) +
|
|
670
|
+
128 * (nk_i32_t)b_sums.u32s[2]);
|
|
671
|
+
result->u32s[3] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_d->product_lo_i32x4) +
|
|
672
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_d->product_hi_i32x4) +
|
|
673
|
+
128 * (nk_i32_t)b_sums.u32s[3]);
|
|
674
|
+
}
|
|
675
|
+
|
|
676
|
+
typedef struct nk_sum_u8x16_state_v128relaxed_t {
|
|
677
|
+
v128_t sum_u32x4;
|
|
678
|
+
} nk_sum_u8x16_state_v128relaxed_t;
|
|
679
|
+
|
|
680
|
+
NK_INTERNAL void nk_sum_u8x16_init_v128relaxed(nk_sum_u8x16_state_v128relaxed_t *state) {
|
|
681
|
+
state->sum_u32x4 = wasm_i32x4_splat(0);
|
|
682
|
+
}
|
|
683
|
+
|
|
684
|
+
NK_INTERNAL void nk_sum_u8x16_update_v128relaxed(nk_sum_u8x16_state_v128relaxed_t *state, nk_b128_vec_t v) {
|
|
685
|
+
v128_t sum_u16x8 = wasm_u16x8_extadd_pairwise_u8x16(v.v128);
|
|
686
|
+
v128_t sum_u32x4 = wasm_u32x4_extadd_pairwise_u16x8(sum_u16x8);
|
|
687
|
+
state->sum_u32x4 = wasm_i32x4_add(state->sum_u32x4, sum_u32x4);
|
|
688
|
+
}
|
|
689
|
+
|
|
690
|
+
NK_INTERNAL nk_u32_t nk_sum_u8x16_finalize_v128relaxed(nk_sum_u8x16_state_v128relaxed_t const *state, nk_size_t count) {
|
|
691
|
+
nk_unused_(count);
|
|
692
|
+
return nk_reduce_add_u32x4_v128relaxed_(state->sum_u32x4);
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
typedef struct nk_dot_e2m3x16_state_v128relaxed_t {
|
|
696
|
+
v128_t sum_i32x4; // relaxed_dot accumulator (a_signed × b_unsigned)
|
|
697
|
+
} nk_dot_e2m3x16_state_v128relaxed_t;
|
|
698
|
+
|
|
699
|
+
NK_INTERNAL void nk_dot_e2m3x16_init_v128relaxed(nk_dot_e2m3x16_state_v128relaxed_t *state) {
|
|
700
|
+
state->sum_i32x4 = wasm_i32x4_splat(0);
|
|
701
|
+
}
|
|
702
|
+
|
|
703
|
+
NK_INTERNAL void nk_dot_e2m3x16_update_v128relaxed(nk_dot_e2m3x16_state_v128relaxed_t *state, nk_b128_vec_t a,
|
|
704
|
+
nk_b128_vec_t b, nk_size_t depth_offset,
|
|
705
|
+
nk_size_t active_dimensions) {
|
|
706
|
+
nk_unused_(depth_offset);
|
|
707
|
+
nk_unused_(active_dimensions);
|
|
708
|
+
// Same LUT-based approach as 1:1 dot, accumulating into state
|
|
709
|
+
v128_t lut_lower_u8x16 = wasm_i8x16_const(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
|
|
710
|
+
v128_t lut_upper_u8x16 = wasm_i8x16_const(32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120);
|
|
711
|
+
v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
|
|
712
|
+
v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
|
|
713
|
+
v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
|
|
714
|
+
v128_t sign_mask_u8x16 = wasm_u8x16_splat(0x20);
|
|
715
|
+
|
|
716
|
+
// Extract magnitude indices
|
|
717
|
+
v128_t a_mag_u8x16 = wasm_v128_and(a.v128, magnitude_mask_u8x16);
|
|
718
|
+
v128_t b_mag_u8x16 = wasm_v128_and(b.v128, magnitude_mask_u8x16);
|
|
719
|
+
|
|
720
|
+
// Dual swizzle LUT for a
|
|
721
|
+
v128_t a_idx_u8x16 = wasm_v128_and(a_mag_u8x16, nibble_mask_u8x16);
|
|
722
|
+
v128_t a_lo_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, a_idx_u8x16);
|
|
723
|
+
v128_t a_hi_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, a_idx_u8x16);
|
|
724
|
+
v128_t a_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_mag_u8x16, half_select_u8x16), half_select_u8x16);
|
|
725
|
+
v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_hi_u8x16, a_lo_u8x16, a_sel_u8x16);
|
|
726
|
+
|
|
727
|
+
// Dual swizzle LUT for b
|
|
728
|
+
v128_t b_idx_u8x16 = wasm_v128_and(b_mag_u8x16, nibble_mask_u8x16);
|
|
729
|
+
v128_t b_lo_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, b_idx_u8x16);
|
|
730
|
+
v128_t b_hi_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, b_idx_u8x16);
|
|
731
|
+
v128_t b_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_mag_u8x16, half_select_u8x16), half_select_u8x16);
|
|
732
|
+
v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_hi_u8x16, b_lo_u8x16, b_sel_u8x16);
|
|
733
|
+
|
|
734
|
+
// Combined sign → apply to a (relaxed_dot wants i8 × u7)
|
|
735
|
+
v128_t sign_u8x16 = wasm_v128_and(wasm_v128_xor(a.v128, b.v128), sign_mask_u8x16);
|
|
736
|
+
v128_t neg_mask_u8x16 = wasm_i8x16_eq(sign_u8x16, sign_mask_u8x16);
|
|
737
|
+
v128_t a_neg_u8x16 = wasm_i8x16_neg(a_unsigned_u8x16);
|
|
738
|
+
v128_t a_signed_i8x16 = wasm_i8x16_relaxed_laneselect(a_neg_u8x16, a_unsigned_u8x16, neg_mask_u8x16);
|
|
739
|
+
|
|
740
|
+
// relaxed_dot accumulate
|
|
741
|
+
state->sum_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_unsigned_u8x16, state->sum_i32x4);
|
|
742
|
+
}
|
|
743
|
+
|
|
744
|
+
NK_INTERNAL void nk_dot_e2m3x16_finalize_v128relaxed( //
|
|
745
|
+
nk_dot_e2m3x16_state_v128relaxed_t const *state_a, nk_dot_e2m3x16_state_v128relaxed_t const *state_b, //
|
|
746
|
+
nk_dot_e2m3x16_state_v128relaxed_t const *state_c, nk_dot_e2m3x16_state_v128relaxed_t const *state_d, //
|
|
747
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
748
|
+
nk_unused_(total_dimensions);
|
|
749
|
+
// Standard 4-way reduce, divide by 256.0f (LUT values are scaled ×16 for each operand)
|
|
750
|
+
nk_f32_t inv_256 = 1.0f / 256.0f;
|
|
751
|
+
result->f32s[0] = (nk_f32_t)nk_reduce_add_i32x4_v128relaxed_(state_a->sum_i32x4) * inv_256;
|
|
752
|
+
result->f32s[1] = (nk_f32_t)nk_reduce_add_i32x4_v128relaxed_(state_b->sum_i32x4) * inv_256;
|
|
753
|
+
result->f32s[2] = (nk_f32_t)nk_reduce_add_i32x4_v128relaxed_(state_c->sum_i32x4) * inv_256;
|
|
754
|
+
result->f32s[3] = (nk_f32_t)nk_reduce_add_i32x4_v128relaxed_(state_d->sum_i32x4) * inv_256;
|
|
755
|
+
}
|
|
756
|
+
|
|
757
|
+
typedef struct nk_dot_e3m2x16_state_v128relaxed_t {
|
|
758
|
+
v128_t sum_i32x4;
|
|
759
|
+
} nk_dot_e3m2x16_state_v128relaxed_t;
|
|
760
|
+
|
|
761
|
+
NK_INTERNAL void nk_dot_e3m2x16_init_v128relaxed(nk_dot_e3m2x16_state_v128relaxed_t *state) {
|
|
762
|
+
state->sum_i32x4 = wasm_i32x4_splat(0);
|
|
763
|
+
}
|
|
764
|
+
|
|
765
|
+
NK_INTERNAL void nk_dot_e3m2x16_update_v128relaxed(nk_dot_e3m2x16_state_v128relaxed_t *state, nk_b128_vec_t a,
|
|
766
|
+
nk_b128_vec_t b, nk_size_t depth_offset,
|
|
767
|
+
nk_size_t active_dimensions) {
|
|
768
|
+
nk_unused_(depth_offset);
|
|
769
|
+
nk_unused_(active_dimensions);
|
|
770
|
+
// ×4 scaled LUT — all values ≤ 112, fits u7 for relaxed_dot
|
|
771
|
+
// Indices 0-11 rounded to nearest integer (max error ±0.5 in ×4 domain = ±0.125 in value)
|
|
772
|
+
// Indices 12-31 exact
|
|
773
|
+
v128_t lut_lower_u8x16 = wasm_i8x16_const(0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 6, 7);
|
|
774
|
+
v128_t lut_upper_u8x16 = wasm_i8x16_const(8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80, 96, 112);
|
|
775
|
+
v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
|
|
776
|
+
v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
|
|
777
|
+
v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
|
|
778
|
+
v128_t sign_mask_u8x16 = wasm_u8x16_splat(0x20);
|
|
779
|
+
|
|
780
|
+
v128_t a_mag_u8x16 = wasm_v128_and(a.v128, magnitude_mask_u8x16);
|
|
781
|
+
v128_t b_mag_u8x16 = wasm_v128_and(b.v128, magnitude_mask_u8x16);
|
|
782
|
+
|
|
783
|
+
// Dual swizzle LUT for a
|
|
784
|
+
v128_t a_idx_u8x16 = wasm_v128_and(a_mag_u8x16, nibble_mask_u8x16);
|
|
785
|
+
v128_t a_lo_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, a_idx_u8x16);
|
|
786
|
+
v128_t a_hi_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, a_idx_u8x16);
|
|
787
|
+
v128_t a_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_mag_u8x16, half_select_u8x16), half_select_u8x16);
|
|
788
|
+
v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_hi_u8x16, a_lo_u8x16, a_sel_u8x16);
|
|
789
|
+
|
|
790
|
+
// Dual swizzle LUT for b
|
|
791
|
+
v128_t b_idx_u8x16 = wasm_v128_and(b_mag_u8x16, nibble_mask_u8x16);
|
|
792
|
+
v128_t b_lo_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, b_idx_u8x16);
|
|
793
|
+
v128_t b_hi_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, b_idx_u8x16);
|
|
794
|
+
v128_t b_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_mag_u8x16, half_select_u8x16), half_select_u8x16);
|
|
795
|
+
v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_hi_u8x16, b_lo_u8x16, b_sel_u8x16);
|
|
796
|
+
|
|
797
|
+
// Combined sign → apply to a (relaxed_dot wants i8 × u7)
|
|
798
|
+
v128_t sign_u8x16 = wasm_v128_and(wasm_v128_xor(a.v128, b.v128), sign_mask_u8x16);
|
|
799
|
+
v128_t neg_mask_u8x16 = wasm_i8x16_eq(sign_u8x16, sign_mask_u8x16);
|
|
800
|
+
v128_t a_neg_u8x16 = wasm_i8x16_neg(a_unsigned_u8x16);
|
|
801
|
+
v128_t a_signed_i8x16 = wasm_i8x16_relaxed_laneselect(a_neg_u8x16, a_unsigned_u8x16, neg_mask_u8x16);
|
|
802
|
+
|
|
803
|
+
state->sum_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_unsigned_u8x16, state->sum_i32x4);
|
|
804
|
+
}
|
|
805
|
+
|
|
806
|
+
NK_INTERNAL void nk_dot_e3m2x16_finalize_v128relaxed( //
|
|
807
|
+
nk_dot_e3m2x16_state_v128relaxed_t const *state_a, nk_dot_e3m2x16_state_v128relaxed_t const *state_b, //
|
|
808
|
+
nk_dot_e3m2x16_state_v128relaxed_t const *state_c, nk_dot_e3m2x16_state_v128relaxed_t const *state_d, //
|
|
809
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
810
|
+
nk_unused_(total_dimensions);
|
|
811
|
+
// ×4 per operand → ×16 product scaling → divide by 16
|
|
812
|
+
nk_f32_t inv_16 = 1.0f / 16.0f;
|
|
813
|
+
result->f32s[0] = (nk_f32_t)nk_reduce_add_i32x4_v128relaxed_(state_a->sum_i32x4) * inv_16;
|
|
814
|
+
result->f32s[1] = (nk_f32_t)nk_reduce_add_i32x4_v128relaxed_(state_b->sum_i32x4) * inv_16;
|
|
815
|
+
result->f32s[2] = (nk_f32_t)nk_reduce_add_i32x4_v128relaxed_(state_c->sum_i32x4) * inv_16;
|
|
816
|
+
result->f32s[3] = (nk_f32_t)nk_reduce_add_i32x4_v128relaxed_(state_d->sum_i32x4) * inv_16;
|
|
817
|
+
}
|
|
818
|
+
|
|
819
|
+
/*  Both FP8 dot-product states (e4m3 and e5m2) are aliases of the shared
 *  `struct nk_dot_through_f32x4_state_v128relaxed_t_` accumulator, which is presumably defined
 *  earlier in this file (not visible here). */
typedef struct nk_dot_through_f32x4_state_v128relaxed_t_ nk_dot_e4m3x4_state_v128relaxed_t; // e4m3 alias
typedef struct nk_dot_through_f32x4_state_v128relaxed_t_ nk_dot_e5m2x4_state_v128relaxed_t; // e5m2 alias
|
|
821
|
+
|
|
822
|
+
/** Reads four packed e4m3 bytes from `src` into a 32-bit scratch value and widens them to four f32 lanes in `*dst`. */
NK_INTERNAL void nk_load_e4m3x4_to_f32x4_v128relaxed_(void const *src, nk_b128_vec_t *dst) {
    nk_b32_vec_t packed_e4m3x4;
    nk_b128_vec_t widened_f32x4;
    nk_copy_bytes_(&packed_e4m3x4, src, 4); // byte-wise copy into the scratch value
    widened_f32x4 = nk_e4m3x4_to_f32x4_v128relaxed_(packed_e4m3x4);
    *dst = widened_f32x4;
}
|
|
827
|
+
|
|
828
|
+
/**
 *  Reads the first `n` e4m3 values from `src` (`n` must not exceed 4), leaves the remaining bytes
 *  as 0x00, and widens all four lanes to f32 in `*dst`.
 */
NK_INTERNAL void nk_partial_load_e4m3x4_to_f32x4_v128relaxed_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
    nk_b32_vec_t packed_e4m3x4 = {0}; // lanes past `n` stay zero
    nk_size_t const byte_count = n * sizeof(nk_e4m3_t);
    nk_copy_bytes_(&packed_e4m3x4, src, byte_count);
    *dst = nk_e4m3x4_to_f32x4_v128relaxed_(packed_e4m3x4);
}
|
|
833
|
+
|
|
834
|
+
/** Reads four packed e5m2 bytes from `src` into a 32-bit scratch value and widens them to four f32 lanes in `*dst`. */
NK_INTERNAL void nk_load_e5m2x4_to_f32x4_v128relaxed_(void const *src, nk_b128_vec_t *dst) {
    nk_b32_vec_t packed_e5m2x4;
    nk_b128_vec_t widened_f32x4;
    nk_copy_bytes_(&packed_e5m2x4, src, 4); // byte-wise copy into the scratch value
    widened_f32x4 = nk_e5m2x4_to_f32x4_v128relaxed_(packed_e5m2x4);
    *dst = widened_f32x4;
}
|
|
839
|
+
|
|
840
|
+
/**
 *  Reads the first `n` e5m2 values from `src` (`n` must not exceed 4), leaves the remaining bytes
 *  as 0x00, and widens all four lanes to f32 in `*dst`.
 */
NK_INTERNAL void nk_partial_load_e5m2x4_to_f32x4_v128relaxed_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
    nk_b32_vec_t packed_e5m2x4 = {0}; // lanes past `n` stay zero
    nk_size_t const byte_count = n * sizeof(nk_e5m2_t);
    nk_copy_bytes_(&packed_e5m2x4, src, byte_count);
    *dst = nk_e5m2x4_to_f32x4_v128relaxed_(packed_e5m2x4);
}
|
|
845
|
+
|
|
846
|
+
NK_PUBLIC void nk_dot_e4m3_v128relaxed(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
847
|
+
nk_f32_t *result) {
|
|
848
|
+
v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
|
|
849
|
+
|
|
850
|
+
nk_dot_e4m3_v128relaxed_cycle:
|
|
851
|
+
if (count_scalars < 4) {
|
|
852
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
853
|
+
nk_partial_load_e4m3x4_to_f32x4_v128relaxed_(a_scalars, &a_vec, count_scalars);
|
|
854
|
+
nk_partial_load_e4m3x4_to_f32x4_v128relaxed_(b_scalars, &b_vec, count_scalars);
|
|
855
|
+
sum_f32x4 = wasm_f32x4_relaxed_madd(a_vec.v128, b_vec.v128, sum_f32x4);
|
|
856
|
+
count_scalars = 0;
|
|
857
|
+
}
|
|
858
|
+
else {
|
|
859
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
860
|
+
nk_load_e4m3x4_to_f32x4_v128relaxed_(a_scalars, &a_vec);
|
|
861
|
+
nk_load_e4m3x4_to_f32x4_v128relaxed_(b_scalars, &b_vec);
|
|
862
|
+
sum_f32x4 = wasm_f32x4_relaxed_madd(a_vec.v128, b_vec.v128, sum_f32x4);
|
|
863
|
+
a_scalars += 4, b_scalars += 4, count_scalars -= 4;
|
|
864
|
+
}
|
|
865
|
+
if (count_scalars) goto nk_dot_e4m3_v128relaxed_cycle;
|
|
866
|
+
|
|
867
|
+
*result = nk_reduce_add_f32x4_v128relaxed_(sum_f32x4);
|
|
868
|
+
}
|
|
869
|
+
|
|
870
|
+
NK_PUBLIC void nk_dot_e5m2_v128relaxed(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
871
|
+
nk_f32_t *result) {
|
|
872
|
+
v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
|
|
873
|
+
|
|
874
|
+
nk_dot_e5m2_v128relaxed_cycle:
|
|
875
|
+
if (count_scalars < 4) {
|
|
876
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
877
|
+
nk_partial_load_e5m2x4_to_f32x4_v128relaxed_(a_scalars, &a_vec, count_scalars);
|
|
878
|
+
nk_partial_load_e5m2x4_to_f32x4_v128relaxed_(b_scalars, &b_vec, count_scalars);
|
|
879
|
+
sum_f32x4 = wasm_f32x4_relaxed_madd(a_vec.v128, b_vec.v128, sum_f32x4);
|
|
880
|
+
count_scalars = 0;
|
|
881
|
+
}
|
|
882
|
+
else {
|
|
883
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
884
|
+
nk_load_e5m2x4_to_f32x4_v128relaxed_(a_scalars, &a_vec);
|
|
885
|
+
nk_load_e5m2x4_to_f32x4_v128relaxed_(b_scalars, &b_vec);
|
|
886
|
+
sum_f32x4 = wasm_f32x4_relaxed_madd(a_vec.v128, b_vec.v128, sum_f32x4);
|
|
887
|
+
a_scalars += 4, b_scalars += 4, count_scalars -= 4;
|
|
888
|
+
}
|
|
889
|
+
if (count_scalars) goto nk_dot_e5m2_v128relaxed_cycle;
|
|
890
|
+
|
|
891
|
+
*result = nk_reduce_add_f32x4_v128relaxed_(sum_f32x4);
|
|
892
|
+
}
|
|
893
|
+
|
|
894
|
+
/**
 *  Dot product of two packed unsigned 4-bit vectors (two nibbles per byte), written to `*result`.
 *  `n` counts nibbles. An odd `n` is rounded up, so the high nibble of the last byte is included.
 *  Nibbles lie in [0, 15], which also satisfies the 7-bit second-operand range of the relaxed dot.
 */
NK_PUBLIC void nk_dot_u4_v128relaxed(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
    nk_size_t bytes_left = nk_size_round_up_to_multiple_(n, 2) / 2;
    v128_t const low_nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
    v128_t dot_i32x4 = wasm_i32x4_splat(0);

    do {
        v128_t a_packed_u8x16, b_packed_u8x16;
        if (bytes_left >= 16) {
            a_packed_u8x16 = wasm_v128_load(a);
            b_packed_u8x16 = wasm_v128_load(b);
            a = (nk_u4x2_t const *)((nk_u8_t const *)a + 16);
            b = (nk_u4x2_t const *)((nk_u8_t const *)b + 16);
            bytes_left -= 16;
        }
        else {
            nk_b128_vec_t a_tail = {0}, b_tail = {0};
            nk_partial_load_b8x16_serial_(a, &a_tail, bytes_left);
            nk_partial_load_b8x16_serial_(b, &b_tail, bytes_left);
            a_packed_u8x16 = a_tail.v128;
            b_packed_u8x16 = b_tail.v128;
            bytes_left = 0;
        }

        // A 16-bit shift is enough: the mask drops bits that spill across byte boundaries.
        v128_t const a_lo_u8x16 = wasm_v128_and(a_packed_u8x16, low_nibble_mask_u8x16);
        v128_t const b_lo_u8x16 = wasm_v128_and(b_packed_u8x16, low_nibble_mask_u8x16);
        v128_t const a_hi_u8x16 = wasm_v128_and(wasm_u16x8_shr(a_packed_u8x16, 4), low_nibble_mask_u8x16);
        v128_t const b_hi_u8x16 = wasm_v128_and(wasm_u16x8_shr(b_packed_u8x16, 4), low_nibble_mask_u8x16);

        dot_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_lo_u8x16, b_lo_u8x16, dot_i32x4);
        dot_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_hi_u8x16, b_hi_u8x16, dot_i32x4);
    } while (bytes_left != 0);

    *result = (nk_u32_t)nk_reduce_add_i32x4_v128relaxed_(dot_i32x4);
}
|
|
931
|
+
|
|
932
|
+
typedef struct nk_dot_u4x32_state_v128relaxed_t {
|
|
933
|
+
v128_t sum_i32x4;
|
|
934
|
+
} nk_dot_u4x32_state_v128relaxed_t;
|
|
935
|
+
|
|
936
|
+
/** Clears all four i32 lanes of the u4 accumulator. */
NK_INTERNAL void nk_dot_u4x32_init_v128relaxed(nk_dot_u4x32_state_v128relaxed_t *state) {
    v128_t const zero_i32x4 = wasm_u32x4_const(0, 0, 0, 0);
    state->sum_i32x4 = zero_i32x4;
}
|
|
939
|
+
|
|
940
|
+
/**
 *  Adds the products of 32 u4 nibble pairs (16 packed bytes of `a` and `b`) into `state`.
 *  `depth_offset` and `active_dimensions` are ignored. Inactive nibbles are presumably zero-padded
 *  by the caller, so they add nothing.
 */
NK_INTERNAL void nk_dot_u4x32_update_v128relaxed(nk_dot_u4x32_state_v128relaxed_t *state, nk_b128_vec_t a,
                                                 nk_b128_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    v128_t const low_nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
    // Move each byte's high nibble into its low half; the mask removes cross-byte spill.
    v128_t const a_shifted_u8x16 = wasm_u16x8_shr(a.v128, 4);
    v128_t const b_shifted_u8x16 = wasm_u16x8_shr(b.v128, 4);
    v128_t running_i32x4 = state->sum_i32x4;

    running_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(wasm_v128_and(a.v128, low_nibble_mask_u8x16),
                                                           wasm_v128_and(b.v128, low_nibble_mask_u8x16), running_i32x4);
    running_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(wasm_v128_and(a_shifted_u8x16, low_nibble_mask_u8x16),
                                                           wasm_v128_and(b_shifted_u8x16, low_nibble_mask_u8x16),
                                                           running_i32x4);
    state->sum_i32x4 = running_i32x4;
}
|
|
952
|
+
|
|
953
|
+
NK_INTERNAL void nk_dot_u4x32_finalize_v128relaxed( //
|
|
954
|
+
nk_dot_u4x32_state_v128relaxed_t const *state_a, nk_dot_u4x32_state_v128relaxed_t const *state_b, //
|
|
955
|
+
nk_dot_u4x32_state_v128relaxed_t const *state_c, nk_dot_u4x32_state_v128relaxed_t const *state_d, //
|
|
956
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
957
|
+
nk_unused_(total_dimensions);
|
|
958
|
+
result->u32s[0] = (nk_u32_t)nk_reduce_add_i32x4_v128relaxed_(state_a->sum_i32x4);
|
|
959
|
+
result->u32s[1] = (nk_u32_t)nk_reduce_add_i32x4_v128relaxed_(state_b->sum_i32x4);
|
|
960
|
+
result->u32s[2] = (nk_u32_t)nk_reduce_add_i32x4_v128relaxed_(state_c->sum_i32x4);
|
|
961
|
+
result->u32s[3] = (nk_u32_t)nk_reduce_add_i32x4_v128relaxed_(state_d->sum_i32x4);
|
|
962
|
+
}
|
|
963
|
+
|
|
964
|
+
/**
 *  Dot product of two packed signed 4-bit vectors (two's-complement nibbles, two per byte).
 *  `n` counts nibbles. An odd `n` is rounded up, so the high nibble of the last byte is included.
 *
 *  The relaxed i8 x i7 dot needs a non-negative second operand, so every nibble is biased:
 *  XOR with 8 maps the signed range [-8, 7] onto [0, 15], i.e. c = a + 8 and d = b + 8. Then
 *      sum(a*b) = sum(c*d) - 8 * (sum(c) + sum(d)) + 64 * N
 *  where N is the number of nibbles processed by the SIMD loop. The leftover (< 16) bytes are
 *  handled with scalar signed arithmetic and need no correction.
 */
NK_PUBLIC void nk_dot_i4_v128relaxed(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result) {
    nk_size_t const total_bytes = nk_size_round_up_to_multiple_(n, 2) / 2;
    nk_size_t const simd_bytes = total_bytes - total_bytes % 16;
    nk_u8_t const *const a_bytes = (nk_u8_t const *)a;
    nk_u8_t const *const b_bytes = (nk_u8_t const *)b;
    v128_t const nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
    v128_t const bias_u8x16 = wasm_u8x16_splat(0x08);
    nk_i64_t biased_dot = 0, a_biased_sum = 0, b_biased_sum = 0;
    nk_size_t offset = 0;

    while (offset < simd_bytes) {
        // Each step adds at most 2 x 15 = 30 to a u16 lane; 128 steps keep it <= 3840 < 65535.
        nk_size_t const window_end = (simd_bytes - offset > 128 * 16) ? offset + 128 * 16 : simd_bytes;
        v128_t dot_i32x4 = wasm_i32x4_splat(0);
        v128_t a_lo_sum_u16x8 = wasm_u16x8_splat(0), a_hi_sum_u16x8 = wasm_u16x8_splat(0);
        v128_t b_lo_sum_u16x8 = wasm_u16x8_splat(0), b_hi_sum_u16x8 = wasm_u16x8_splat(0);

        for (; offset < window_end; offset += 16) {
            v128_t const a_packed = wasm_v128_load(a_bytes + offset);
            v128_t const b_packed = wasm_v128_load(b_bytes + offset);
            v128_t const c_lo_u8x16 = wasm_v128_xor(wasm_v128_and(a_packed, nibble_mask_u8x16), bias_u8x16);
            v128_t const c_hi_u8x16 =
                wasm_v128_xor(wasm_v128_and(wasm_u16x8_shr(a_packed, 4), nibble_mask_u8x16), bias_u8x16);
            v128_t const d_lo_u8x16 = wasm_v128_xor(wasm_v128_and(b_packed, nibble_mask_u8x16), bias_u8x16);
            v128_t const d_hi_u8x16 =
                wasm_v128_xor(wasm_v128_and(wasm_u16x8_shr(b_packed, 4), nibble_mask_u8x16), bias_u8x16);

            dot_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(c_lo_u8x16, d_lo_u8x16, dot_i32x4);
            dot_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(c_hi_u8x16, d_hi_u8x16, dot_i32x4);

            // One u8 -> u16 widening per vector; u16 -> u32 widening is deferred to the window end.
            a_lo_sum_u16x8 = wasm_i16x8_add(a_lo_sum_u16x8, wasm_u16x8_extadd_pairwise_u8x16(c_lo_u8x16));
            a_hi_sum_u16x8 = wasm_i16x8_add(a_hi_sum_u16x8, wasm_u16x8_extadd_pairwise_u8x16(c_hi_u8x16));
            b_lo_sum_u16x8 = wasm_i16x8_add(b_lo_sum_u16x8, wasm_u16x8_extadd_pairwise_u8x16(d_lo_u8x16));
            b_hi_sum_u16x8 = wasm_i16x8_add(b_hi_sum_u16x8, wasm_u16x8_extadd_pairwise_u8x16(d_hi_u8x16));
        }

        {
            v128_t const a_sum_u32x4 = wasm_i32x4_add(wasm_u32x4_extadd_pairwise_u16x8(a_lo_sum_u16x8),
                                                      wasm_u32x4_extadd_pairwise_u16x8(a_hi_sum_u16x8));
            v128_t const b_sum_u32x4 = wasm_i32x4_add(wasm_u32x4_extadd_pairwise_u16x8(b_lo_sum_u16x8),
                                                      wasm_u32x4_extadd_pairwise_u16x8(b_hi_sum_u16x8));
            biased_dot += nk_reduce_add_i32x4_v128relaxed_(dot_i32x4);
            a_biased_sum += nk_reduce_add_u32x4_v128relaxed_(a_sum_u32x4);
            b_biased_sum += nk_reduce_add_u32x4_v128relaxed_(b_sum_u32x4);
        }
    }

    // Scalar tail: the nibble helpers yield signed values directly, so no bias is involved.
    nk_i64_t tail_dot = 0;
    for (nk_size_t idx = simd_bytes; idx < total_bytes; ++idx) {
        nk_i4x2_t const a_pair = ((nk_i4x2_t const *)a_bytes)[idx];
        nk_i4x2_t const b_pair = ((nk_i4x2_t const *)b_bytes)[idx];
        tail_dot += (nk_i32_t)nk_i4x2_low_(a_pair) * (nk_i32_t)nk_i4x2_low_(b_pair) +
                    (nk_i32_t)nk_i4x2_high_(a_pair) * (nk_i32_t)nk_i4x2_high_(b_pair);
    }

    // Remove the bias from the SIMD part only (two nibbles per byte), then add the exact tail.
    nk_i64_t const simd_nibbles = (nk_i64_t)simd_bytes * 2;
    *result = (nk_i32_t)(biased_dot - 8 * (a_biased_sum + b_biased_sum) + 64 * simd_nibbles + tail_dot);
}
|
|
1037
|
+
|
|
1038
|
+
typedef struct nk_dot_i4x32_state_v128relaxed_t {
|
|
1039
|
+
v128_t biased_product_sum_i32x4;
|
|
1040
|
+
} nk_dot_i4x32_state_v128relaxed_t;
|
|
1041
|
+
|
|
1042
|
+
/** Clears the biased i4 product accumulator. */
NK_INTERNAL void nk_dot_i4x32_init_v128relaxed(nk_dot_i4x32_state_v128relaxed_t *state) {
    v128_t const zero_i32x4 = wasm_u32x4_const(0, 0, 0, 0);
    state->biased_product_sum_i32x4 = zero_i32x4;
}
|
|
1045
|
+
|
|
1046
|
+
/**
 *  Adds the products of 32 biased i4 nibble pairs (16 packed bytes of `a` and `b`) into `state`.
 *  Each nibble is XOR-ed with 8, mapping signed [-8, 7] to [0, 15]. The bias is removed later by
 *  `nk_dot_i4x32_finalize_v128relaxed`. `depth_offset` and `active_dimensions` are ignored.
 */
NK_INTERNAL void nk_dot_i4x32_update_v128relaxed(nk_dot_i4x32_state_v128relaxed_t *state, nk_b128_vec_t a,
                                                 nk_b128_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    v128_t const nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
    v128_t const bias_u8x16 = wasm_u8x16_splat(0x08);
    // High nibbles moved into each byte's low half; the mask discards cross-byte spill.
    v128_t const a_high_src = wasm_u16x8_shr(a.v128, 4);
    v128_t const b_high_src = wasm_u16x8_shr(b.v128, 4);

    v128_t const a_lo_biased = wasm_v128_xor(wasm_v128_and(a.v128, nibble_mask_u8x16), bias_u8x16);
    v128_t const a_hi_biased = wasm_v128_xor(wasm_v128_and(a_high_src, nibble_mask_u8x16), bias_u8x16);
    v128_t const b_lo_biased = wasm_v128_xor(wasm_v128_and(b.v128, nibble_mask_u8x16), bias_u8x16);
    v128_t const b_hi_biased = wasm_v128_xor(wasm_v128_and(b_high_src, nibble_mask_u8x16), bias_u8x16);

    v128_t running_i32x4 = state->biased_product_sum_i32x4;
    running_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_lo_biased, b_lo_biased, running_i32x4);
    running_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_hi_biased, b_hi_biased, running_i32x4);
    state->biased_product_sum_i32x4 = running_i32x4;
}
|
|
1061
|
+
|
|
1062
|
+
NK_INTERNAL void nk_dot_i4x32_finalize_v128relaxed( //
|
|
1063
|
+
nk_dot_i4x32_state_v128relaxed_t const *state_a, nk_dot_i4x32_state_v128relaxed_t const *state_b, //
|
|
1064
|
+
nk_dot_i4x32_state_v128relaxed_t const *state_c, nk_dot_i4x32_state_v128relaxed_t const *state_d, //
|
|
1065
|
+
nk_size_t total_dimensions, //
|
|
1066
|
+
nk_i32_t a_sum, /* Row sum of A (signed sum of i4 values) */ //
|
|
1067
|
+
nk_b128_vec_t b_sums, /* 4 × i32 column sums of B */ //
|
|
1068
|
+
nk_b128_vec_t *result) {
|
|
1069
|
+
// Match x86 compensated i4 finalizers: result = biased_dot - 8*(a_sum + b_sum) - 64*depth_padded
|
|
1070
|
+
nk_i64_t depth_padded = (nk_i64_t)nk_size_round_up_to_multiple_(total_dimensions, 32);
|
|
1071
|
+
result->i32s[0] = nk_reduce_add_i32x4_v128relaxed_(state_a->biased_product_sum_i32x4) -
|
|
1072
|
+
8 * ((nk_i64_t)a_sum + (nk_i64_t)b_sums.i32s[0]) - 64 * depth_padded;
|
|
1073
|
+
result->i32s[1] = nk_reduce_add_i32x4_v128relaxed_(state_b->biased_product_sum_i32x4) -
|
|
1074
|
+
8 * ((nk_i64_t)a_sum + (nk_i64_t)b_sums.i32s[1]) - 64 * depth_padded;
|
|
1075
|
+
result->i32s[2] = nk_reduce_add_i32x4_v128relaxed_(state_c->biased_product_sum_i32x4) -
|
|
1076
|
+
8 * ((nk_i64_t)a_sum + (nk_i64_t)b_sums.i32s[2]) - 64 * depth_padded;
|
|
1077
|
+
result->i32s[3] = nk_reduce_add_i32x4_v128relaxed_(state_d->biased_product_sum_i32x4) -
|
|
1078
|
+
8 * ((nk_i64_t)a_sum + (nk_i64_t)b_sums.i32s[3]) - 64 * depth_padded;
|
|
1079
|
+
}
|
|
1080
|
+
|
|
1081
|
+
typedef struct nk_sum_i4x32_state_v128relaxed_t {
|
|
1082
|
+
v128_t sum_i32x4;
|
|
1083
|
+
} nk_sum_i4x32_state_v128relaxed_t;
|
|
1084
|
+
|
|
1085
|
+
/** Clears the signed i4 sum accumulator. */
NK_INTERNAL void nk_sum_i4x32_init_v128relaxed(nk_sum_i4x32_state_v128relaxed_t *state) {
    v128_t const zero_i32x4 = wasm_u32x4_const(0, 0, 0, 0);
    state->sum_i32x4 = zero_i32x4;
}
|
|
1088
|
+
|
|
1089
|
+
/**
 *  Adds the signed values of 32 packed i4 nibbles (16 bytes of `v`) to `state`.
 *  Nibbles are biased via XOR 8 (signed s becomes s + 8 in [0, 15]), summed as unsigned values,
 *  and the total bias is then subtracted. A zero nibble therefore contributes exactly 0.
 */
NK_INTERNAL void nk_sum_i4x32_update_v128relaxed(nk_sum_i4x32_state_v128relaxed_t *state, nk_b128_vec_t v) {
    v128_t const nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
    v128_t const bias_u8x16 = wasm_u8x16_splat(0x08);
    v128_t const low_biased_u8x16 = wasm_v128_xor(wasm_v128_and(v.v128, nibble_mask_u8x16), bias_u8x16);
    v128_t const high_biased_u8x16 =
        wasm_v128_xor(wasm_v128_and(wasm_u16x8_shr(v.v128, 4), nibble_mask_u8x16), bias_u8x16);
    // The two biased nibbles of one byte sum to at most 30, so the byte-wise add cannot wrap.
    v128_t const byte_sums_u8x16 = wasm_i8x16_add(low_biased_u8x16, high_biased_u8x16);
    v128_t const lane_sums_u32x4 =
        wasm_u32x4_extadd_pairwise_u16x8(wasm_u16x8_extadd_pairwise_u8x16(byte_sums_u8x16));
    // Each i32 lane covers 4 bytes = 8 nibbles, each carrying +8, so 64 is removed per lane.
    v128_t const debiased_i32x4 = wasm_i32x4_sub(lane_sums_u32x4, wasm_i32x4_splat(64));
    state->sum_i32x4 = wasm_i32x4_add(state->sum_i32x4, debiased_i32x4);
}
|
|
1099
|
+
|
|
1100
|
+
/**
 *  Returns the accumulated signed i4 sum. `count` is unused: zero padding nibbles contribute 0
 *  after debiasing in the update step.
 */
NK_INTERNAL nk_i32_t nk_sum_i4x32_finalize_v128relaxed(nk_sum_i4x32_state_v128relaxed_t const *state, nk_size_t count) {
    nk_i32_t const total = nk_reduce_add_i32x4_v128relaxed_(state->sum_i32x4);
    nk_unused_(count);
    return total;
}
|
|
1104
|
+
|
|
1105
|
+
NK_PUBLIC void nk_dot_f32c_v128relaxed(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
1106
|
+
nk_f64c_t *result) {
|
|
1107
|
+
v128_t sum_real_f64x2 = wasm_f64x2_splat(0.0);
|
|
1108
|
+
v128_t sum_imag_f64x2 = wasm_f64x2_splat(0.0);
|
|
1109
|
+
v128_t sign_flip_i64x2 = wasm_i64x2_const(0, 0x8000000000000000ULL);
|
|
1110
|
+
|
|
1111
|
+
nk_size_t idx_pairs = 0;
|
|
1112
|
+
for (; idx_pairs != count_pairs; ++idx_pairs) {
|
|
1113
|
+
// Load [real, imag] as 64 bits, promote to f64x2
|
|
1114
|
+
v128_t a_f32x2 = wasm_v128_load64_zero((nk_f32_t const *)(a_pairs + idx_pairs));
|
|
1115
|
+
v128_t b_f32x2 = wasm_v128_load64_zero((nk_f32_t const *)(b_pairs + idx_pairs));
|
|
1116
|
+
v128_t a_real_imag_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
|
|
1117
|
+
v128_t b_real_imag_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
|
|
1118
|
+
|
|
1119
|
+
// Swap b: [imag, real]
|
|
1120
|
+
v128_t b_swapped_f64x2 = wasm_i64x2_shuffle(b_real_imag_f64x2, b_real_imag_f64x2, 1, 0);
|
|
1121
|
+
|
|
1122
|
+
// Accumulate: real part uses a*b directly, imag part uses a*b_swapped
|
|
1123
|
+
sum_real_f64x2 = wasm_f64x2_relaxed_madd(a_real_imag_f64x2, b_real_imag_f64x2, sum_real_f64x2);
|
|
1124
|
+
sum_imag_f64x2 = wasm_f64x2_relaxed_madd(a_real_imag_f64x2, b_swapped_f64x2, sum_imag_f64x2);
|
|
1125
|
+
}
|
|
1126
|
+
|
|
1127
|
+
// Flip sign of lane 1 in sum_real: real = Σ(aᵣ*bᵣ) - Σ(aᵢ*bᵢ)
|
|
1128
|
+
sum_real_f64x2 = wasm_v128_xor(sum_real_f64x2, sign_flip_i64x2);
|
|
1129
|
+
|
|
1130
|
+
// Finalize: real = sum_real[0] + sum_real[1], imag = sum_imag[0] + sum_imag[1]
|
|
1131
|
+
nk_f64_t real_part = wasm_f64x2_extract_lane(sum_real_f64x2, 0) + wasm_f64x2_extract_lane(sum_real_f64x2, 1);
|
|
1132
|
+
nk_f64_t imag_part = wasm_f64x2_extract_lane(sum_imag_f64x2, 0) + wasm_f64x2_extract_lane(sum_imag_f64x2, 1);
|
|
1133
|
+
result->real = real_part;
|
|
1134
|
+
result->imag = imag_part;
|
|
1135
|
+
}
|
|
1136
|
+
|
|
1137
|
+
NK_PUBLIC void nk_vdot_f32c_v128relaxed(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
1138
|
+
nk_f64c_t *result) {
|
|
1139
|
+
v128_t sum_real_f64x2 = wasm_f64x2_splat(0.0);
|
|
1140
|
+
v128_t sum_imag_f64x2 = wasm_f64x2_splat(0.0);
|
|
1141
|
+
v128_t sign_flip_i64x2 = wasm_i64x2_const(0, 0x8000000000000000ULL);
|
|
1142
|
+
|
|
1143
|
+
nk_size_t idx_pairs = 0;
|
|
1144
|
+
for (; idx_pairs != count_pairs; ++idx_pairs) {
|
|
1145
|
+
v128_t a_f32x2 = wasm_v128_load64_zero((nk_f32_t const *)(a_pairs + idx_pairs));
|
|
1146
|
+
v128_t b_f32x2 = wasm_v128_load64_zero((nk_f32_t const *)(b_pairs + idx_pairs));
|
|
1147
|
+
v128_t a_real_imag_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
|
|
1148
|
+
v128_t b_real_imag_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
|
|
1149
|
+
v128_t b_swapped_f64x2 = wasm_i64x2_shuffle(b_real_imag_f64x2, b_real_imag_f64x2, 1, 0);
|
|
1150
|
+
|
|
1151
|
+
sum_real_f64x2 = wasm_f64x2_relaxed_madd(a_real_imag_f64x2, b_real_imag_f64x2, sum_real_f64x2);
|
|
1152
|
+
sum_imag_f64x2 = wasm_f64x2_relaxed_madd(a_real_imag_f64x2, b_swapped_f64x2, sum_imag_f64x2);
|
|
1153
|
+
}
|
|
1154
|
+
|
|
1155
|
+
// For vdot (conjugate dot): flip sign of imag lane 1 instead
|
|
1156
|
+
sum_imag_f64x2 = wasm_v128_xor(sum_imag_f64x2, sign_flip_i64x2);
|
|
1157
|
+
|
|
1158
|
+
nk_f64_t real_part = wasm_f64x2_extract_lane(sum_real_f64x2, 0) + wasm_f64x2_extract_lane(sum_real_f64x2, 1);
|
|
1159
|
+
nk_f64_t imag_part = wasm_f64x2_extract_lane(sum_imag_f64x2, 0) + wasm_f64x2_extract_lane(sum_imag_f64x2, 1);
|
|
1160
|
+
result->real = real_part;
|
|
1161
|
+
result->imag = imag_part;
|
|
1162
|
+
}
|
|
1163
|
+
|
|
1164
|
+
/**
 *  Complex dot product sum(a_k * b_k) over f64 complex pairs.
 *  `rr_ii` collects [ar*br, ai*bi] and `ri_ir` collects [ar*bi, ai*br]. Negating lane 1 of `rr_ii`
 *  before the horizontal add gives real = ar*br - ai*bi and imag = ar*bi + ai*br.
 */
NK_PUBLIC void nk_dot_f64c_v128relaxed(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                       nk_f64c_t *result) {
    v128_t const negate_lane1_mask = wasm_i64x2_const(0, 0x8000000000000000ULL);
    v128_t rr_ii_f64x2 = wasm_f64x2_splat(0.0);
    v128_t ri_ir_f64x2 = wasm_f64x2_splat(0.0);

    for (; count_pairs != 0; --count_pairs, ++a_pairs, ++b_pairs) {
        v128_t const a_f64x2 = wasm_v128_load((nk_f64_t const *)a_pairs);
        v128_t const b_f64x2 = wasm_v128_load((nk_f64_t const *)b_pairs);
        v128_t const b_swapped_f64x2 = wasm_i64x2_shuffle(b_f64x2, b_f64x2, 1, 0); // [bi, br]
        rr_ii_f64x2 = wasm_f64x2_relaxed_madd(a_f64x2, b_f64x2, rr_ii_f64x2);
        ri_ir_f64x2 = wasm_f64x2_relaxed_madd(a_f64x2, b_swapped_f64x2, ri_ir_f64x2);
    }

    rr_ii_f64x2 = wasm_v128_xor(rr_ii_f64x2, negate_lane1_mask);
    result->real = wasm_f64x2_extract_lane(rr_ii_f64x2, 0) + wasm_f64x2_extract_lane(rr_ii_f64x2, 1);
    result->imag = wasm_f64x2_extract_lane(ri_ir_f64x2, 0) + wasm_f64x2_extract_lane(ri_ir_f64x2, 1);
}
|
|
1185
|
+
|
|
1186
|
+
/**
 *  Conjugate complex dot product sum(conj(a_k) * b_k) over f64 complex pairs.
 *  `rr_ii` collects [ar*br, ai*bi] and `ri_ir` collects [ar*bi, ai*br]. Negating lane 1 of `ri_ir`
 *  before the horizontal add gives real = ar*br + ai*bi and imag = ar*bi - ai*br.
 */
NK_PUBLIC void nk_vdot_f64c_v128relaxed(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                        nk_f64c_t *result) {
    v128_t const negate_lane1_mask = wasm_i64x2_const(0, 0x8000000000000000ULL);
    v128_t rr_ii_f64x2 = wasm_f64x2_splat(0.0);
    v128_t ri_ir_f64x2 = wasm_f64x2_splat(0.0);

    for (; count_pairs != 0; --count_pairs, ++a_pairs, ++b_pairs) {
        v128_t const a_f64x2 = wasm_v128_load((nk_f64_t const *)a_pairs);
        v128_t const b_f64x2 = wasm_v128_load((nk_f64_t const *)b_pairs);
        v128_t const b_swapped_f64x2 = wasm_i64x2_shuffle(b_f64x2, b_f64x2, 1, 0); // [bi, br]
        rr_ii_f64x2 = wasm_f64x2_relaxed_madd(a_f64x2, b_f64x2, rr_ii_f64x2);
        ri_ir_f64x2 = wasm_f64x2_relaxed_madd(a_f64x2, b_swapped_f64x2, ri_ir_f64x2);
    }

    ri_ir_f64x2 = wasm_v128_xor(ri_ir_f64x2, negate_lane1_mask);
    result->real = wasm_f64x2_extract_lane(rr_ii_f64x2, 0) + wasm_f64x2_extract_lane(rr_ii_f64x2, 1);
    result->imag = wasm_f64x2_extract_lane(ri_ir_f64x2, 0) + wasm_f64x2_extract_lane(ri_ir_f64x2, 1);
}
|
|
1207
|
+
|
|
1208
|
+
typedef struct nk_dot_u1x128_state_v128relaxed_t {
|
|
1209
|
+
v128_t dot_count_u32x4;
|
|
1210
|
+
} nk_dot_u1x128_state_v128relaxed_t;
|
|
1211
|
+
|
|
1212
|
+
/** Clears the u1 popcount accumulator. */
NK_INTERNAL void nk_dot_u1x128_init_v128relaxed(nk_dot_u1x128_state_v128relaxed_t *state) {
    v128_t const zero_u32x4 = wasm_i32x4_splat(0);
    state->dot_count_u32x4 = zero_u32x4;
}
|
|
1215
|
+
|
|
1216
|
+
/**
 *  Adds popcount(a & b) over 128 bits to `state`. The per-byte counts (<= 8) are widened pairwise
 *  twice, giving four u32 partial sums. `depth_offset` and `active_dimensions` are ignored.
 */
NK_INTERNAL void nk_dot_u1x128_update_v128relaxed(nk_dot_u1x128_state_v128relaxed_t *state, nk_b128_vec_t a,
                                                  nk_b128_vec_t b, nk_size_t depth_offset,
                                                  nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    v128_t const shared_bits_u8x16 = wasm_i8x16_popcnt(wasm_v128_and(a.v128, b.v128));
    v128_t const shared_bits_u32x4 =
        wasm_u32x4_extadd_pairwise_u16x8(wasm_u16x8_extadd_pairwise_u8x16(shared_bits_u8x16));
    state->dot_count_u32x4 = wasm_i32x4_add(state->dot_count_u32x4, shared_bits_u32x4);
}
|
|
1227
|
+
|
|
1228
|
+
NK_INTERNAL void nk_dot_u1x128_finalize_v128relaxed( //
|
|
1229
|
+
nk_dot_u1x128_state_v128relaxed_t const *state_a, nk_dot_u1x128_state_v128relaxed_t const *state_b, //
|
|
1230
|
+
nk_dot_u1x128_state_v128relaxed_t const *state_c, nk_dot_u1x128_state_v128relaxed_t const *state_d, //
|
|
1231
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
1232
|
+
nk_unused_(total_dimensions);
|
|
1233
|
+
v128_t a_u32x4 = state_a->dot_count_u32x4, b_u32x4 = state_b->dot_count_u32x4;
|
|
1234
|
+
v128_t c_u32x4 = state_c->dot_count_u32x4, d_u32x4 = state_d->dot_count_u32x4;
|
|
1235
|
+
// Step 1: interleave pairs
|
|
1236
|
+
v128_t ab_lo_u32x4 = wasm_i32x4_shuffle(a_u32x4, b_u32x4, 0, 4, 1, 5); // a0 b0 a1 b1
|
|
1237
|
+
v128_t ab_hi_u32x4 = wasm_i32x4_shuffle(a_u32x4, b_u32x4, 2, 6, 3, 7); // a2 b2 a3 b3
|
|
1238
|
+
v128_t cd_lo_u32x4 = wasm_i32x4_shuffle(c_u32x4, d_u32x4, 0, 4, 1, 5); // c0 d0 c1 d1
|
|
1239
|
+
v128_t cd_hi_u32x4 = wasm_i32x4_shuffle(c_u32x4, d_u32x4, 2, 6, 3, 7); // c2 d2 c3 d3
|
|
1240
|
+
// Step 2: pairwise add
|
|
1241
|
+
v128_t sum_02_u32x4 = wasm_i32x4_add(ab_lo_u32x4, ab_hi_u32x4); // a02 b02 a13 b13
|
|
1242
|
+
v128_t sum_13_u32x4 = wasm_i32x4_add(cd_lo_u32x4, cd_hi_u32x4); // c02 d02 c13 d13
|
|
1243
|
+
// Step 3: final interleave
|
|
1244
|
+
v128_t even_u32x4 = wasm_i32x4_shuffle(sum_02_u32x4, sum_13_u32x4, 0, 1, 4, 5);
|
|
1245
|
+
v128_t odd_u32x4 = wasm_i32x4_shuffle(sum_02_u32x4, sum_13_u32x4, 2, 3, 6, 7);
|
|
1246
|
+
result->v128 = wasm_i32x4_add(even_u32x4, odd_u32x4); // [sum_a, sum_b, sum_c, sum_d]
|
|
1247
|
+
}
|
|
1248
|
+
|
|
1249
|
+
#if defined(__clang__)
|
|
1250
|
+
#pragma clang attribute pop
|
|
1251
|
+
#endif
|
|
1252
|
+
|
|
1253
|
+
#if defined(__cplusplus)
|
|
1254
|
+
} // extern "C"
|
|
1255
|
+
#endif
|
|
1256
|
+
|
|
1257
|
+
#endif // NK_TARGET_V128RELAXED
|
|
1258
|
+
#endif // NK_DOT_V128RELAXED_H
|