numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,16 +8,16 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_svehalf_instructions ARM SVE+FP16 Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* svld1_f16
|
|
13
|
-
* svld2_f16
|
|
14
|
-
* svmla_f16_x
|
|
15
|
-
* svmls_f16_x
|
|
16
|
-
* svaddv_f16
|
|
17
|
-
* svdup_f16
|
|
18
|
-
* svwhilelt_b16
|
|
19
|
-
* svptrue_b16
|
|
20
|
-
* svcnth
|
|
11
|
+
* Intrinsic Instruction V1
|
|
12
|
+
* svld1_f16 LD1H (Z.H, P/Z, [Xn]) 4-6cy @ 2p
|
|
13
|
+
* svld2_f16 LD2H (Z.H, P/Z, [Xn]) 6-8cy @ 1p
|
|
14
|
+
* svmla_f16_x FMLA (Z.H, P/M, Z.H, Z.H) 4cy @ 2p
|
|
15
|
+
* svmls_f16_x FMLS (Z.H, P/M, Z.H, Z.H) 4cy @ 2p
|
|
16
|
+
* svaddv_f16 FADDV (H, P, Z.H) 6cy @ 1p
|
|
17
|
+
* svdup_f16 DUP (Z.H, #imm) 1cy @ 2p
|
|
18
|
+
* svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy @ 1p
|
|
19
|
+
* svptrue_b16 PTRUE (P.H, pattern) 1cy @ 2p
|
|
20
|
+
* svcnth CNTH (Xd) 1cy @ 2p
|
|
21
21
|
*
|
|
22
22
|
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
23
23
|
* and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
|
|
@@ -51,13 +51,21 @@ NK_PUBLIC void nk_dot_f16_svehalf(nk_f16_t const *a_scalars, nk_f16_t const *b_s
|
|
|
51
51
|
nk_size_t idx_scalars = 0;
|
|
52
52
|
svfloat32_t ab_f32x = svdup_f32(0);
|
|
53
53
|
do {
|
|
54
|
-
svbool_t
|
|
55
|
-
svfloat16_t a_f16x = svld1_f16(
|
|
56
|
-
svfloat16_t b_f16x = svld1_f16(
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
54
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(idx_scalars, count_scalars);
|
|
55
|
+
svfloat16_t a_f16x = svld1_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(a_scalars) + idx_scalars);
|
|
56
|
+
svfloat16_t b_f16x = svld1_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(b_scalars) + idx_scalars);
|
|
57
|
+
nk_size_t remaining = count_scalars - idx_scalars < svcnth() ? count_scalars - idx_scalars : svcnth();
|
|
58
|
+
|
|
59
|
+
// svcvt_f32_f16_x widens only even-indexed f16 elements; svext by 1 shifts odd into even.
|
|
60
|
+
svbool_t pred_even_b32x = svwhilelt_b32_u64(0u, (remaining + 1) / 2);
|
|
61
|
+
ab_f32x = svmla_f32_m(pred_even_b32x, ab_f32x, svcvt_f32_f16_x(pred_even_b32x, a_f16x),
|
|
62
|
+
svcvt_f32_f16_x(pred_even_b32x, b_f16x));
|
|
63
|
+
|
|
64
|
+
svbool_t pred_odd_b32x = svwhilelt_b32_u64(0u, remaining / 2);
|
|
65
|
+
ab_f32x = svmla_f32_m(pred_odd_b32x, ab_f32x, svcvt_f32_f16_x(pred_odd_b32x, svext_f16(a_f16x, a_f16x, 1)),
|
|
66
|
+
svcvt_f32_f16_x(pred_odd_b32x, svext_f16(b_f16x, b_f16x, 1)));
|
|
67
|
+
|
|
68
|
+
idx_scalars += svcnth();
|
|
61
69
|
} while (idx_scalars < count_scalars);
|
|
62
70
|
*result = svaddv_f32(svptrue_b32(), ab_f32x);
|
|
63
71
|
}
|
|
@@ -68,18 +76,36 @@ NK_PUBLIC void nk_dot_f16c_svehalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_
|
|
|
68
76
|
svfloat32_t ab_real_f32x = svdup_f32(0);
|
|
69
77
|
svfloat32_t ab_imag_f32x = svdup_f32(0);
|
|
70
78
|
do {
|
|
71
|
-
svbool_t
|
|
72
|
-
svfloat16x2_t
|
|
73
|
-
svfloat16x2_t
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
79
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(idx_scalars, count_pairs);
|
|
80
|
+
svfloat16x2_t a_f16x2x = svld2_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(a_pairs) + idx_scalars * 2);
|
|
81
|
+
svfloat16x2_t b_f16x2x = svld2_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(b_pairs) + idx_scalars * 2);
|
|
82
|
+
svfloat16_t ar_f16x = svget2_f16(a_f16x2x, 0), ai_f16x = svget2_f16(a_f16x2x, 1);
|
|
83
|
+
svfloat16_t br_f16x = svget2_f16(b_f16x2x, 0), bi_f16x = svget2_f16(b_f16x2x, 1);
|
|
84
|
+
nk_size_t remaining = count_pairs - idx_scalars < svcnth() ? count_pairs - idx_scalars : svcnth();
|
|
85
|
+
|
|
86
|
+
// Even-indexed elements of each deinterleaved component
|
|
87
|
+
svbool_t pred_even_b32x = svwhilelt_b32_u64(0u, (remaining + 1) / 2);
|
|
88
|
+
svfloat32_t ar_even_f32x = svcvt_f32_f16_x(pred_even_b32x, ar_f16x);
|
|
89
|
+
svfloat32_t ai_even_f32x = svcvt_f32_f16_x(pred_even_b32x, ai_f16x);
|
|
90
|
+
svfloat32_t br_even_f32x = svcvt_f32_f16_x(pred_even_b32x, br_f16x);
|
|
91
|
+
svfloat32_t bi_even_f32x = svcvt_f32_f16_x(pred_even_b32x, bi_f16x);
|
|
92
|
+
ab_real_f32x = svmla_f32_m(pred_even_b32x, ab_real_f32x, ar_even_f32x, br_even_f32x);
|
|
93
|
+
ab_real_f32x = svmls_f32_m(pred_even_b32x, ab_real_f32x, ai_even_f32x, bi_even_f32x);
|
|
94
|
+
ab_imag_f32x = svmla_f32_m(pred_even_b32x, ab_imag_f32x, ar_even_f32x, bi_even_f32x);
|
|
95
|
+
ab_imag_f32x = svmla_f32_m(pred_even_b32x, ab_imag_f32x, ai_even_f32x, br_even_f32x);
|
|
96
|
+
|
|
97
|
+
// Odd-indexed elements via svext shift-by-1
|
|
98
|
+
svbool_t pred_odd_b32x = svwhilelt_b32_u64(0u, remaining / 2);
|
|
99
|
+
svfloat32_t ar_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(ar_f16x, ar_f16x, 1));
|
|
100
|
+
svfloat32_t ai_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(ai_f16x, ai_f16x, 1));
|
|
101
|
+
svfloat32_t br_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(br_f16x, br_f16x, 1));
|
|
102
|
+
svfloat32_t bi_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(bi_f16x, bi_f16x, 1));
|
|
103
|
+
ab_real_f32x = svmla_f32_m(pred_odd_b32x, ab_real_f32x, ar_odd_f32x, br_odd_f32x);
|
|
104
|
+
ab_real_f32x = svmls_f32_m(pred_odd_b32x, ab_real_f32x, ai_odd_f32x, bi_odd_f32x);
|
|
105
|
+
ab_imag_f32x = svmla_f32_m(pred_odd_b32x, ab_imag_f32x, ar_odd_f32x, bi_odd_f32x);
|
|
106
|
+
ab_imag_f32x = svmla_f32_m(pred_odd_b32x, ab_imag_f32x, ai_odd_f32x, br_odd_f32x);
|
|
107
|
+
|
|
108
|
+
idx_scalars += svcnth();
|
|
83
109
|
} while (idx_scalars < count_pairs);
|
|
84
110
|
results->real = svaddv_f32(svptrue_b32(), ab_real_f32x);
|
|
85
111
|
results->imag = svaddv_f32(svptrue_b32(), ab_imag_f32x);
|
|
@@ -91,18 +117,36 @@ NK_PUBLIC void nk_vdot_f16c_svehalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b
|
|
|
91
117
|
svfloat32_t ab_real_f32x = svdup_f32(0);
|
|
92
118
|
svfloat32_t ab_imag_f32x = svdup_f32(0);
|
|
93
119
|
do {
|
|
94
|
-
svbool_t
|
|
95
|
-
svfloat16x2_t
|
|
96
|
-
svfloat16x2_t
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
120
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(idx_scalars, count_pairs);
|
|
121
|
+
svfloat16x2_t a_f16x2x = svld2_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(a_pairs) + idx_scalars * 2);
|
|
122
|
+
svfloat16x2_t b_f16x2x = svld2_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(b_pairs) + idx_scalars * 2);
|
|
123
|
+
svfloat16_t ar_f16x = svget2_f16(a_f16x2x, 0), ai_f16x = svget2_f16(a_f16x2x, 1);
|
|
124
|
+
svfloat16_t br_f16x = svget2_f16(b_f16x2x, 0), bi_f16x = svget2_f16(b_f16x2x, 1);
|
|
125
|
+
nk_size_t remaining = count_pairs - idx_scalars < svcnth() ? count_pairs - idx_scalars : svcnth();
|
|
126
|
+
|
|
127
|
+
// Even-indexed elements
|
|
128
|
+
svbool_t pred_even_b32x = svwhilelt_b32_u64(0u, (remaining + 1) / 2);
|
|
129
|
+
svfloat32_t ar_even_f32x = svcvt_f32_f16_x(pred_even_b32x, ar_f16x);
|
|
130
|
+
svfloat32_t ai_even_f32x = svcvt_f32_f16_x(pred_even_b32x, ai_f16x);
|
|
131
|
+
svfloat32_t br_even_f32x = svcvt_f32_f16_x(pred_even_b32x, br_f16x);
|
|
132
|
+
svfloat32_t bi_even_f32x = svcvt_f32_f16_x(pred_even_b32x, bi_f16x);
|
|
133
|
+
ab_real_f32x = svmla_f32_m(pred_even_b32x, ab_real_f32x, ar_even_f32x, br_even_f32x);
|
|
134
|
+
ab_real_f32x = svmla_f32_m(pred_even_b32x, ab_real_f32x, ai_even_f32x, bi_even_f32x);
|
|
135
|
+
ab_imag_f32x = svmla_f32_m(pred_even_b32x, ab_imag_f32x, ar_even_f32x, bi_even_f32x);
|
|
136
|
+
ab_imag_f32x = svmls_f32_m(pred_even_b32x, ab_imag_f32x, ai_even_f32x, br_even_f32x);
|
|
137
|
+
|
|
138
|
+
// Odd-indexed elements via svext shift-by-1
|
|
139
|
+
svbool_t pred_odd_b32x = svwhilelt_b32_u64(0u, remaining / 2);
|
|
140
|
+
svfloat32_t ar_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(ar_f16x, ar_f16x, 1));
|
|
141
|
+
svfloat32_t ai_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(ai_f16x, ai_f16x, 1));
|
|
142
|
+
svfloat32_t br_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(br_f16x, br_f16x, 1));
|
|
143
|
+
svfloat32_t bi_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(bi_f16x, bi_f16x, 1));
|
|
144
|
+
ab_real_f32x = svmla_f32_m(pred_odd_b32x, ab_real_f32x, ar_odd_f32x, br_odd_f32x);
|
|
145
|
+
ab_real_f32x = svmla_f32_m(pred_odd_b32x, ab_real_f32x, ai_odd_f32x, bi_odd_f32x);
|
|
146
|
+
ab_imag_f32x = svmla_f32_m(pred_odd_b32x, ab_imag_f32x, ar_odd_f32x, bi_odd_f32x);
|
|
147
|
+
ab_imag_f32x = svmls_f32_m(pred_odd_b32x, ab_imag_f32x, ai_odd_f32x, br_odd_f32x);
|
|
148
|
+
|
|
149
|
+
idx_scalars += svcnth();
|
|
106
150
|
} while (idx_scalars < count_pairs);
|
|
107
151
|
results->real = svaddv_f32(svptrue_b32(), ab_real_f32x);
|
|
108
152
|
results->imag = svaddv_f32(svptrue_b32(), ab_imag_f32x);
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for SVE SDOT.
|
|
3
|
+
* @file include/numkong/dot/svesdot.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date April 3, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_svesdot_instructions ARM SVE+DotProd Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction V1
|
|
12
|
+
* svld1_s8 LD1B (Z.B, P/Z, [Xn]) 4-6cy @ 2p
|
|
13
|
+
* svld1_u8 LD1B (Z.B, P/Z, [Xn]) 4-6cy @ 2p
|
|
14
|
+
* svdot_s32 SDOT (Z.S, Z.B, Z.B) 3cy @ 2p
|
|
15
|
+
* svdot_u32 UDOT (Z.S, Z.B, Z.B) 3cy @ 2p
|
|
16
|
+
* svaddv_s32 SADDV (D, P, Z.S) 6cy @ 1p
|
|
17
|
+
* svaddv_u32 UADDV (D, P, Z.S) 6cy @ 1p
|
|
18
|
+
* svdup_s32 DUP (Z.S, #imm) 1cy @ 2p
|
|
19
|
+
* svwhilelt_b8 WHILELT (P.B, Xn, Xm) 2cy @ 1p
|
|
20
|
+
* svcntb CNTB (Xd) 1cy @ 2p
|
|
21
|
+
*
|
|
22
|
+
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
23
|
+
* and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
|
|
24
|
+
* process more elements per iteration with identical latencies.
|
|
25
|
+
*
|
|
26
|
+
* The SDOT/UDOT instructions fuse four int8 multiplications with int32 accumulation per lane,
|
|
27
|
+
* providing the same 4-way dot product as NEON SDOT but with scalable vector widths.
|
|
28
|
+
* On 256-bit SVE, this processes 32 int8 elements per instruction vs NEON's fixed 16.
|
|
29
|
+
*/
|
|
30
|
+
#ifndef NK_DOT_SVESDOT_H
|
|
31
|
+
#define NK_DOT_SVESDOT_H
|
|
32
|
+
|
|
33
|
+
#if NK_TARGET_ARM_
|
|
34
|
+
#if NK_TARGET_SVESDOT
|
|
35
|
+
|
|
36
|
+
#include "numkong/types.h"
|
|
37
|
+
|
|
38
|
+
#if defined(__cplusplus)
|
|
39
|
+
extern "C" {
|
|
40
|
+
#endif
|
|
41
|
+
|
|
42
|
+
#if defined(__clang__)
|
|
43
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+dotprod"))), apply_to = function)
|
|
44
|
+
#elif defined(__GNUC__)
|
|
45
|
+
#pragma GCC push_options
|
|
46
|
+
#pragma GCC target("arch=armv8.2-a+sve+dotprod")
|
|
47
|
+
#endif
|
|
48
|
+
|
|
49
|
+
NK_PUBLIC void nk_dot_i8_svesdot(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
|
|
50
|
+
nk_i32_t *result) {
|
|
51
|
+
nk_size_t idx_scalars = 0;
|
|
52
|
+
svint32_t sum_i32x = svdup_s32(0);
|
|
53
|
+
do {
|
|
54
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(idx_scalars, count_scalars);
|
|
55
|
+
svint8_t a_i8x = svld1_s8(predicate_b8x, a_scalars + idx_scalars);
|
|
56
|
+
svint8_t b_i8x = svld1_s8(predicate_b8x, b_scalars + idx_scalars);
|
|
57
|
+
sum_i32x = svdot_s32(sum_i32x, a_i8x, b_i8x);
|
|
58
|
+
idx_scalars += svcntb();
|
|
59
|
+
} while (idx_scalars < count_scalars);
|
|
60
|
+
*result = (nk_i32_t)svaddv_s32(svptrue_b32(), sum_i32x);
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
NK_PUBLIC void nk_dot_u8_svesdot(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
|
|
64
|
+
nk_u32_t *result) {
|
|
65
|
+
nk_size_t idx_scalars = 0;
|
|
66
|
+
svuint32_t sum_u32x = svdup_u32(0);
|
|
67
|
+
do {
|
|
68
|
+
svbool_t predicate_b8x = svwhilelt_b8_u64(idx_scalars, count_scalars);
|
|
69
|
+
svuint8_t a_u8x = svld1_u8(predicate_b8x, a_scalars + idx_scalars);
|
|
70
|
+
svuint8_t b_u8x = svld1_u8(predicate_b8x, b_scalars + idx_scalars);
|
|
71
|
+
sum_u32x = svdot_u32(sum_u32x, a_u8x, b_u8x);
|
|
72
|
+
idx_scalars += svcntb();
|
|
73
|
+
} while (idx_scalars < count_scalars);
|
|
74
|
+
*result = (nk_u32_t)svaddv_u32(svptrue_b32(), sum_u32x);
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
#if defined(__clang__)
|
|
78
|
+
#pragma clang attribute pop
|
|
79
|
+
#elif defined(__GNUC__)
|
|
80
|
+
#pragma GCC pop_options
|
|
81
|
+
#endif
|
|
82
|
+
|
|
83
|
+
#if defined(__cplusplus)
|
|
84
|
+
} // extern "C"
|
|
85
|
+
#endif
|
|
86
|
+
|
|
87
|
+
#endif // NK_TARGET_SVESDOT
|
|
88
|
+
#endif // NK_TARGET_ARM_
|
|
89
|
+
#endif // NK_DOT_SVESDOT_H
|
|
@@ -73,8 +73,8 @@ nk_dot_f32_v128relaxed_cycle:
|
|
|
73
73
|
nk_load_b64_serial_(b_scalars, &b_f32_vec);
|
|
74
74
|
a_scalars += 2, b_scalars += 2, count_scalars -= 2;
|
|
75
75
|
}
|
|
76
|
-
v128_t a_f32x2 =
|
|
77
|
-
v128_t b_f32x2 =
|
|
76
|
+
v128_t a_f32x2 = wasm_i64x2_splat(a_f32_vec.u64);
|
|
77
|
+
v128_t b_f32x2 = wasm_i64x2_splat(b_f32_vec.u64);
|
|
78
78
|
v128_t a_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
|
|
79
79
|
v128_t b_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
|
|
80
80
|
sum_f64x2 = wasm_f64x2_relaxed_madd(a_f64x2, b_f64x2, sum_f64x2);
|
|
@@ -110,24 +110,28 @@ nk_dot_f16_v128relaxed_cycle:
|
|
|
110
110
|
|
|
111
111
|
NK_PUBLIC void nk_dot_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
112
112
|
v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
|
|
113
|
+
v128_t mask_high_u32x4 = wasm_i32x4_splat((int)0xFFFF0000);
|
|
113
114
|
nk_bf16_t const *a_scalars = a, *b_scalars = b;
|
|
114
115
|
nk_size_t count_scalars = n;
|
|
115
|
-
|
|
116
|
+
nk_b128_vec_t a_bf16_vec, b_bf16_vec;
|
|
116
117
|
|
|
117
118
|
nk_dot_bf16_v128relaxed_cycle:
|
|
118
|
-
if (count_scalars <
|
|
119
|
-
|
|
120
|
-
|
|
119
|
+
if (count_scalars < 8) {
|
|
120
|
+
nk_partial_load_b16x8_serial_(a_scalars, &a_bf16_vec, count_scalars);
|
|
121
|
+
nk_partial_load_b16x8_serial_(b_scalars, &b_bf16_vec, count_scalars);
|
|
121
122
|
count_scalars = 0;
|
|
122
123
|
}
|
|
123
124
|
else {
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
a_scalars +=
|
|
125
|
+
nk_load_b128_v128relaxed_(a_scalars, &a_bf16_vec);
|
|
126
|
+
nk_load_b128_v128relaxed_(b_scalars, &b_bf16_vec);
|
|
127
|
+
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
127
128
|
}
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
sum_f32x4 = wasm_f32x4_relaxed_madd(
|
|
129
|
+
v128_t a_even_f32x4 = wasm_i32x4_shl(a_bf16_vec.v128, 16);
|
|
130
|
+
v128_t b_even_f32x4 = wasm_i32x4_shl(b_bf16_vec.v128, 16);
|
|
131
|
+
sum_f32x4 = wasm_f32x4_relaxed_madd(a_even_f32x4, b_even_f32x4, sum_f32x4);
|
|
132
|
+
v128_t a_odd_f32x4 = wasm_v128_and(a_bf16_vec.v128, mask_high_u32x4);
|
|
133
|
+
v128_t b_odd_f32x4 = wasm_v128_and(b_bf16_vec.v128, mask_high_u32x4);
|
|
134
|
+
sum_f32x4 = wasm_f32x4_relaxed_madd(a_odd_f32x4, b_odd_f32x4, sum_f32x4);
|
|
131
135
|
if (count_scalars) goto nk_dot_bf16_v128relaxed_cycle;
|
|
132
136
|
|
|
133
137
|
*result = nk_reduce_add_f32x4_v128relaxed_(sum_f32x4);
|
|
@@ -274,8 +278,8 @@ NK_PUBLIC void nk_dot_e2m3_v128relaxed(nk_e2m3_t const *a_scalars, nk_e2m3_t con
|
|
|
274
278
|
// Result = i32_dot / 256.0f (exact, no rounding error).
|
|
275
279
|
//
|
|
276
280
|
// 32-entry LUT split into two 16-entry halves for wasm_i8x16_relaxed_swizzle (indexes 0-15).
|
|
277
|
-
v128_t
|
|
278
|
-
v128_t
|
|
281
|
+
v128_t lut_low_u8x16 = wasm_i8x16_const(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
|
|
282
|
+
v128_t lut_high_u8x16 = wasm_i8x16_const(32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120);
|
|
279
283
|
v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
|
|
280
284
|
v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
|
|
281
285
|
v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
|
|
@@ -304,17 +308,17 @@ nk_dot_e2m3_v128relaxed_cycle:
|
|
|
304
308
|
|
|
305
309
|
// Dual swizzle + bitselect for 32-entry LUT (a)
|
|
306
310
|
v128_t a_shuffle_index_u8x16 = wasm_v128_and(a_magnitude_u8x16, nibble_mask_u8x16);
|
|
307
|
-
v128_t
|
|
308
|
-
v128_t
|
|
309
|
-
v128_t
|
|
310
|
-
v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(
|
|
311
|
+
v128_t a_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, a_shuffle_index_u8x16);
|
|
312
|
+
v128_t a_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, a_shuffle_index_u8x16);
|
|
313
|
+
v128_t a_high_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
|
|
314
|
+
v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_high_u8x16, a_low_u8x16, a_high_select_u8x16);
|
|
311
315
|
|
|
312
316
|
// Dual swizzle + bitselect for 32-entry LUT (b)
|
|
313
317
|
v128_t b_shuffle_index_u8x16 = wasm_v128_and(b_magnitude_u8x16, nibble_mask_u8x16);
|
|
314
|
-
v128_t
|
|
315
|
-
v128_t
|
|
316
|
-
v128_t
|
|
317
|
-
v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(
|
|
318
|
+
v128_t b_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, b_shuffle_index_u8x16);
|
|
319
|
+
v128_t b_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, b_shuffle_index_u8x16);
|
|
320
|
+
v128_t b_high_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
|
|
321
|
+
v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_high_u8x16, b_low_u8x16, b_high_select_u8x16);
|
|
318
322
|
|
|
319
323
|
// Combined sign: (a ^ b) & 0x20 — nonzero means negative product
|
|
320
324
|
// Apply sign to a (relaxed_dot wants i8 × u7: a_signed, b_unsigned)
|
|
@@ -343,12 +347,13 @@ NK_PUBLIC void nk_dot_e3m2_v128relaxed(nk_e3m2_t const *a_scalars, nk_e3m2_t con
|
|
|
343
347
|
// Low-byte LUT entries (magnitude[i] & 0xFF):
|
|
344
348
|
// [0,1,2,3,4,5,6,7,8,10,12,14,16,20,24,28] lower half
|
|
345
349
|
// [32,40,48,56,64,80,96,112,128,160,192,224,0,64,128,192] upper half
|
|
346
|
-
v128_t
|
|
347
|
-
v128_t
|
|
350
|
+
v128_t lut_low_byte_first_u8x16 = wasm_i8x16_const(0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28);
|
|
351
|
+
v128_t lut_low_byte_second_u8x16 = wasm_u8x16_const(32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 0, 64, 128,
|
|
352
|
+
192);
|
|
348
353
|
v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
|
|
349
354
|
v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
|
|
350
355
|
v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
|
|
351
|
-
v128_t
|
|
356
|
+
v128_t high_threshold_u8x16 = wasm_u8x16_splat(28);
|
|
352
357
|
v128_t sign_mask_u8x16 = wasm_u8x16_splat(0x20);
|
|
353
358
|
v128_t sum_i32x4 = wasm_i32x4_splat(0);
|
|
354
359
|
v128_t a_e3m2_u8x16, b_e3m2_u8x16;
|
|
@@ -374,32 +379,34 @@ nk_dot_e3m2_v128relaxed_cycle:
|
|
|
374
379
|
|
|
375
380
|
// Dual swizzle + bitselect for 32-entry low-byte LUT (a)
|
|
376
381
|
v128_t a_shuffle_index_u8x16 = wasm_v128_and(a_magnitude_u8x16, nibble_mask_u8x16);
|
|
377
|
-
v128_t
|
|
378
|
-
v128_t
|
|
379
|
-
v128_t
|
|
380
|
-
v128_t
|
|
382
|
+
v128_t a_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_byte_first_u8x16, a_shuffle_index_u8x16);
|
|
383
|
+
v128_t a_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_byte_second_u8x16, a_shuffle_index_u8x16);
|
|
384
|
+
v128_t a_high_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
|
|
385
|
+
v128_t a_low_byte_u8x16 = wasm_i8x16_relaxed_laneselect(a_high_u8x16, a_low_u8x16, a_high_select_u8x16);
|
|
381
386
|
|
|
382
387
|
// High byte is 1 iff magnitude index >= 28 (values 256, 320, 384, 448), else 0
|
|
383
|
-
v128_t
|
|
388
|
+
v128_t a_high_byte_u8x16 = wasm_v128_and(wasm_u8x16_ge(a_magnitude_u8x16, high_threshold_u8x16),
|
|
389
|
+
wasm_u8x16_splat(1));
|
|
384
390
|
|
|
385
391
|
// Dual swizzle + bitselect for 32-entry low-byte LUT (b)
|
|
386
392
|
v128_t b_shuffle_index_u8x16 = wasm_v128_and(b_magnitude_u8x16, nibble_mask_u8x16);
|
|
387
|
-
v128_t
|
|
388
|
-
v128_t
|
|
389
|
-
v128_t
|
|
390
|
-
v128_t
|
|
393
|
+
v128_t b_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_byte_first_u8x16, b_shuffle_index_u8x16);
|
|
394
|
+
v128_t b_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_byte_second_u8x16, b_shuffle_index_u8x16);
|
|
395
|
+
v128_t b_high_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
|
|
396
|
+
v128_t b_low_byte_u8x16 = wasm_i8x16_relaxed_laneselect(b_high_u8x16, b_low_u8x16, b_high_select_u8x16);
|
|
391
397
|
|
|
392
398
|
// High byte is 1 iff magnitude index >= 28
|
|
393
|
-
v128_t
|
|
399
|
+
v128_t b_high_byte_u8x16 = wasm_v128_and(wasm_u8x16_ge(b_magnitude_u8x16, high_threshold_u8x16),
|
|
400
|
+
wasm_u8x16_splat(1));
|
|
394
401
|
|
|
395
402
|
// Combine low and high bytes into i16 via byte interleave shuffle (little-endian: low byte first)
|
|
396
|
-
v128_t a_unsigned_low_i16x8 = wasm_i8x16_shuffle(
|
|
403
|
+
v128_t a_unsigned_low_i16x8 = wasm_i8x16_shuffle(a_low_byte_u8x16, a_high_byte_u8x16, 0, 16, 1, 17, 2, 18, 3, 19, 4,
|
|
397
404
|
20, 5, 21, 6, 22, 7, 23);
|
|
398
|
-
v128_t a_unsigned_high_i16x8 = wasm_i8x16_shuffle(
|
|
405
|
+
v128_t a_unsigned_high_i16x8 = wasm_i8x16_shuffle(a_low_byte_u8x16, a_high_byte_u8x16, 8, 24, 9, 25, 10, 26, 11, 27,
|
|
399
406
|
12, 28, 13, 29, 14, 30, 15, 31);
|
|
400
|
-
v128_t b_unsigned_low_i16x8 = wasm_i8x16_shuffle(
|
|
407
|
+
v128_t b_unsigned_low_i16x8 = wasm_i8x16_shuffle(b_low_byte_u8x16, b_high_byte_u8x16, 0, 16, 1, 17, 2, 18, 3, 19, 4,
|
|
401
408
|
20, 5, 21, 6, 22, 7, 23);
|
|
402
|
-
v128_t b_unsigned_high_i16x8 = wasm_i8x16_shuffle(
|
|
409
|
+
v128_t b_unsigned_high_i16x8 = wasm_i8x16_shuffle(b_low_byte_u8x16, b_high_byte_u8x16, 8, 24, 9, 25, 10, 26, 11, 27,
|
|
403
410
|
12, 28, 13, 29, 14, 30, 15, 31);
|
|
404
411
|
|
|
405
412
|
// Combined sign: XOR sign bits, negate only b (saves ~15 ops vs independent negation)
|
|
@@ -497,6 +504,33 @@ NK_INTERNAL void nk_dot_through_f32x4_finalize_v128relaxed_( //
|
|
|
497
504
|
result->f32s[3] = nk_reduce_add_f32x4_v128relaxed_(state_d->sum_f32x4);
|
|
498
505
|
}
|
|
499
506
|
|
|
507
|
+
typedef struct nk_dot_through_f32x4_state_v128relaxed_t_ nk_dot_bf16x8_state_v128relaxed_t;
|
|
508
|
+
|
|
509
|
+
NK_INTERNAL void nk_dot_bf16x8_init_v128relaxed(nk_dot_bf16x8_state_v128relaxed_t *state) {
|
|
510
|
+
nk_dot_through_f32x4_init_v128relaxed_(state);
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
NK_INTERNAL void nk_dot_bf16x8_update_v128relaxed(nk_dot_bf16x8_state_v128relaxed_t *state, nk_b128_vec_t a,
|
|
514
|
+
nk_b128_vec_t b, nk_size_t depth_offset,
|
|
515
|
+
nk_size_t active_dimensions) {
|
|
516
|
+
nk_unused_(depth_offset);
|
|
517
|
+
nk_unused_(active_dimensions);
|
|
518
|
+
v128_t mask_high_u32x4 = wasm_i32x4_splat((int)0xFFFF0000);
|
|
519
|
+
v128_t a_even_f32x4 = wasm_i32x4_shl(a.v128, 16);
|
|
520
|
+
v128_t b_even_f32x4 = wasm_i32x4_shl(b.v128, 16);
|
|
521
|
+
state->sum_f32x4 = wasm_f32x4_relaxed_madd(a_even_f32x4, b_even_f32x4, state->sum_f32x4);
|
|
522
|
+
v128_t a_odd_f32x4 = wasm_v128_and(a.v128, mask_high_u32x4);
|
|
523
|
+
v128_t b_odd_f32x4 = wasm_v128_and(b.v128, mask_high_u32x4);
|
|
524
|
+
state->sum_f32x4 = wasm_f32x4_relaxed_madd(a_odd_f32x4, b_odd_f32x4, state->sum_f32x4);
|
|
525
|
+
}
|
|
526
|
+
|
|
527
|
+
NK_INTERNAL void nk_dot_bf16x8_finalize_v128relaxed( //
|
|
528
|
+
nk_dot_bf16x8_state_v128relaxed_t const *state_a, nk_dot_bf16x8_state_v128relaxed_t const *state_b, //
|
|
529
|
+
nk_dot_bf16x8_state_v128relaxed_t const *state_c, nk_dot_bf16x8_state_v128relaxed_t const *state_d, //
|
|
530
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
531
|
+
nk_dot_through_f32x4_finalize_v128relaxed_(state_a, state_b, state_c, state_d, total_dimensions, result);
|
|
532
|
+
}
|
|
533
|
+
|
|
500
534
|
typedef struct nk_dot_f32x2_state_v128relaxed_t {
|
|
501
535
|
v128_t sum_f64x2;
|
|
502
536
|
} nk_dot_f32x2_state_v128relaxed_t;
|
|
@@ -509,8 +543,8 @@ NK_INTERNAL void nk_dot_f32x2_update_v128relaxed(nk_dot_f32x2_state_v128relaxed_
|
|
|
509
543
|
nk_b64_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
510
544
|
nk_unused_(depth_offset);
|
|
511
545
|
nk_unused_(active_dimensions);
|
|
512
|
-
v128_t a_f32x2 =
|
|
513
|
-
v128_t b_f32x2 =
|
|
546
|
+
v128_t a_f32x2 = wasm_i64x2_splat(a.u64);
|
|
547
|
+
v128_t b_f32x2 = wasm_i64x2_splat(b.u64);
|
|
514
548
|
v128_t a_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
|
|
515
549
|
v128_t b_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
|
|
516
550
|
state->sum_f64x2 = wasm_f64x2_relaxed_madd(a_f64x2, b_f64x2, state->sum_f64x2);
|
|
@@ -603,12 +637,12 @@ NK_INTERNAL void nk_dot_i8x16_update_v128relaxed(nk_dot_i8x16_state_v128relaxed_
|
|
|
603
637
|
nk_b128_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
604
638
|
nk_unused_(depth_offset);
|
|
605
639
|
nk_unused_(active_dimensions);
|
|
606
|
-
// Bit-split: b =
|
|
607
|
-
// So a·b = a·
|
|
608
|
-
v128_t
|
|
609
|
-
v128_t
|
|
610
|
-
state->product_sum_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a.v128,
|
|
611
|
-
state->negative_sum_a_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a.v128,
|
|
640
|
+
// Bit-split: b = b_low + (-128)·b_high where b_low = b & 0x7F ∈ [0,127], b_high = b >> 7 ∈ {0,1}
|
|
641
|
+
// So a·b = a·b_low − 128·a·b_high, both operands fit i7 for relaxed_dot
|
|
642
|
+
v128_t b_low_u8x16 = wasm_v128_and(b.v128, wasm_i8x16_splat(0x7F));
|
|
643
|
+
v128_t b_high_u8x16 = wasm_u8x16_shr(b.v128, 7);
|
|
644
|
+
state->product_sum_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a.v128, b_low_u8x16, state->product_sum_i32x4);
|
|
645
|
+
state->negative_sum_a_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a.v128, b_high_u8x16,
|
|
612
646
|
state->negative_sum_a_i32x4);
|
|
613
647
|
}
|
|
614
648
|
|
|
@@ -629,28 +663,29 @@ NK_INTERNAL void nk_dot_i8x16_finalize_v128relaxed(
|
|
|
629
663
|
}
|
|
630
664
|
|
|
631
665
|
typedef struct nk_dot_u8x16_state_v128relaxed_t {
|
|
632
|
-
v128_t
|
|
633
|
-
v128_t
|
|
666
|
+
v128_t product_low_i32x4; // relaxed_dot(a_signed, b_low) accumulator
|
|
667
|
+
v128_t product_high_i32x4; // relaxed_dot(a_signed, b_high) accumulator
|
|
634
668
|
} nk_dot_u8x16_state_v128relaxed_t;
|
|
635
669
|
|
|
636
670
|
NK_INTERNAL void nk_dot_u8x16_init_v128relaxed(nk_dot_u8x16_state_v128relaxed_t *state) {
|
|
637
|
-
state->
|
|
638
|
-
state->
|
|
671
|
+
state->product_low_i32x4 = wasm_i32x4_splat(0);
|
|
672
|
+
state->product_high_i32x4 = wasm_i32x4_splat(0);
|
|
639
673
|
}
|
|
640
674
|
|
|
641
675
|
NK_INTERNAL void nk_dot_u8x16_update_v128relaxed(nk_dot_u8x16_state_v128relaxed_t *state, nk_b128_vec_t a,
|
|
642
676
|
nk_b128_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
643
677
|
nk_unused_(depth_offset);
|
|
644
678
|
nk_unused_(active_dimensions);
|
|
645
|
-
// Bit-split b: b =
|
|
646
|
-
// Σ a·b = Σ(a_signed+128)·(b_lo+128·
|
|
679
|
+
// Bit-split b: b = b_low + 128·b_high, with a_signed = a ^ 0x80 = a - 128 (reinterpret u8 as i8)
|
|
680
|
+
// Σ a·b = Σ(a_signed+128)·(b_lo+128·b_high) = relaxed_dot(a_signed,b_low) + 128·relaxed_dot(a_signed,b_high) +
|
|
681
|
+
// 128·Σb
|
|
647
682
|
v128_t a_signed_i8x16 = wasm_v128_xor(a.v128, wasm_i8x16_splat((signed char)0x80));
|
|
648
|
-
v128_t
|
|
649
|
-
v128_t
|
|
650
|
-
state->
|
|
651
|
-
|
|
652
|
-
state->
|
|
653
|
-
|
|
683
|
+
v128_t b_low_u8x16 = wasm_v128_and(b.v128, wasm_i8x16_splat(0x7F));
|
|
684
|
+
v128_t b_high_u8x16 = wasm_u8x16_shr(b.v128, 7);
|
|
685
|
+
state->product_low_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_low_u8x16,
|
|
686
|
+
state->product_low_i32x4);
|
|
687
|
+
state->product_high_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_high_u8x16,
|
|
688
|
+
state->product_high_i32x4);
|
|
654
689
|
}
|
|
655
690
|
|
|
656
691
|
NK_INTERNAL void nk_dot_u8x16_finalize_v128relaxed( //
|
|
@@ -659,17 +694,17 @@ NK_INTERNAL void nk_dot_u8x16_finalize_v128relaxed(
|
|
|
659
694
|
nk_size_t total_dimensions, nk_u32_t a_sum, nk_b128_vec_t b_sums, nk_b128_vec_t *result) {
|
|
660
695
|
nk_unused_(a_sum);
|
|
661
696
|
// Σ a·b = reduce(lo) + 128·reduce(hi) + 128·Σb
|
|
662
|
-
result->u32s[0] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_a->
|
|
663
|
-
128 * nk_reduce_add_i32x4_v128relaxed_(state_a->
|
|
697
|
+
result->u32s[0] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_a->product_low_i32x4) +
|
|
698
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_a->product_high_i32x4) +
|
|
664
699
|
128 * (nk_i32_t)b_sums.u32s[0]);
|
|
665
|
-
result->u32s[1] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_b->
|
|
666
|
-
128 * nk_reduce_add_i32x4_v128relaxed_(state_b->
|
|
700
|
+
result->u32s[1] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_b->product_low_i32x4) +
|
|
701
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_b->product_high_i32x4) +
|
|
667
702
|
128 * (nk_i32_t)b_sums.u32s[1]);
|
|
668
|
-
result->u32s[2] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_c->
|
|
669
|
-
128 * nk_reduce_add_i32x4_v128relaxed_(state_c->
|
|
703
|
+
result->u32s[2] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_c->product_low_i32x4) +
|
|
704
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_c->product_high_i32x4) +
|
|
670
705
|
128 * (nk_i32_t)b_sums.u32s[2]);
|
|
671
|
-
result->u32s[3] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_d->
|
|
672
|
-
128 * nk_reduce_add_i32x4_v128relaxed_(state_d->
|
|
706
|
+
result->u32s[3] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_d->product_low_i32x4) +
|
|
707
|
+
128 * nk_reduce_add_i32x4_v128relaxed_(state_d->product_high_i32x4) +
|
|
673
708
|
128 * (nk_i32_t)b_sums.u32s[3]);
|
|
674
709
|
}
|
|
675
710
|
|
|
@@ -706,8 +741,8 @@ NK_INTERNAL void nk_dot_e2m3x16_update_v128relaxed(nk_dot_e2m3x16_state_v128rela
|
|
|
706
741
|
nk_unused_(depth_offset);
|
|
707
742
|
nk_unused_(active_dimensions);
|
|
708
743
|
// Same LUT-based approach as 1:1 dot, accumulating into state
|
|
709
|
-
v128_t
|
|
710
|
-
v128_t
|
|
744
|
+
v128_t lut_low_u8x16 = wasm_i8x16_const(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
|
|
745
|
+
v128_t lut_high_u8x16 = wasm_i8x16_const(32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120);
|
|
711
746
|
v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
|
|
712
747
|
v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
|
|
713
748
|
v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
|
|
@@ -719,17 +754,17 @@ NK_INTERNAL void nk_dot_e2m3x16_update_v128relaxed(nk_dot_e2m3x16_state_v128rela
|
|
|
719
754
|
|
|
720
755
|
// Dual swizzle LUT for a
|
|
721
756
|
v128_t a_idx_u8x16 = wasm_v128_and(a_mag_u8x16, nibble_mask_u8x16);
|
|
722
|
-
v128_t
|
|
723
|
-
v128_t
|
|
757
|
+
v128_t a_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, a_idx_u8x16);
|
|
758
|
+
v128_t a_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, a_idx_u8x16);
|
|
724
759
|
v128_t a_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_mag_u8x16, half_select_u8x16), half_select_u8x16);
|
|
725
|
-
v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(
|
|
760
|
+
v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_high_u8x16, a_low_u8x16, a_sel_u8x16);
|
|
726
761
|
|
|
727
762
|
// Dual swizzle LUT for b
|
|
728
763
|
v128_t b_idx_u8x16 = wasm_v128_and(b_mag_u8x16, nibble_mask_u8x16);
|
|
729
|
-
v128_t
|
|
730
|
-
v128_t
|
|
764
|
+
v128_t b_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, b_idx_u8x16);
|
|
765
|
+
v128_t b_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, b_idx_u8x16);
|
|
731
766
|
v128_t b_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_mag_u8x16, half_select_u8x16), half_select_u8x16);
|
|
732
|
-
v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(
|
|
767
|
+
v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_high_u8x16, b_low_u8x16, b_sel_u8x16);
|
|
733
768
|
|
|
734
769
|
// Combined sign → apply to a (relaxed_dot wants i8 × u7)
|
|
735
770
|
v128_t sign_u8x16 = wasm_v128_and(wasm_v128_xor(a.v128, b.v128), sign_mask_u8x16);
|
|
@@ -770,8 +805,8 @@ NK_INTERNAL void nk_dot_e3m2x16_update_v128relaxed(nk_dot_e3m2x16_state_v128rela
|
|
|
770
805
|
// ×4 scaled LUT — all values ≤ 112, fits u7 for relaxed_dot
|
|
771
806
|
// Indices 0-11 rounded to nearest integer (max error ±0.5 in ×4 domain = ±0.125 in value)
|
|
772
807
|
// Indices 12-31 exact
|
|
773
|
-
v128_t
|
|
774
|
-
v128_t
|
|
808
|
+
v128_t lut_low_u8x16 = wasm_i8x16_const(0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 6, 7);
|
|
809
|
+
v128_t lut_high_u8x16 = wasm_i8x16_const(8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80, 96, 112);
|
|
775
810
|
v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
|
|
776
811
|
v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
|
|
777
812
|
v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
|
|
@@ -782,17 +817,17 @@ NK_INTERNAL void nk_dot_e3m2x16_update_v128relaxed(nk_dot_e3m2x16_state_v128rela
|
|
|
782
817
|
|
|
783
818
|
// Dual swizzle LUT for a
|
|
784
819
|
v128_t a_idx_u8x16 = wasm_v128_and(a_mag_u8x16, nibble_mask_u8x16);
|
|
785
|
-
v128_t
|
|
786
|
-
v128_t
|
|
820
|
+
v128_t a_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, a_idx_u8x16);
|
|
821
|
+
v128_t a_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, a_idx_u8x16);
|
|
787
822
|
v128_t a_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_mag_u8x16, half_select_u8x16), half_select_u8x16);
|
|
788
|
-
v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(
|
|
823
|
+
v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_high_u8x16, a_low_u8x16, a_sel_u8x16);
|
|
789
824
|
|
|
790
825
|
// Dual swizzle LUT for b
|
|
791
826
|
v128_t b_idx_u8x16 = wasm_v128_and(b_mag_u8x16, nibble_mask_u8x16);
|
|
792
|
-
v128_t
|
|
793
|
-
v128_t
|
|
827
|
+
v128_t b_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, b_idx_u8x16);
|
|
828
|
+
v128_t b_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, b_idx_u8x16);
|
|
794
829
|
v128_t b_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_mag_u8x16, half_select_u8x16), half_select_u8x16);
|
|
795
|
-
v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(
|
|
830
|
+
v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_high_u8x16, b_low_u8x16, b_sel_u8x16);
|
|
796
831
|
|
|
797
832
|
// Combined sign → apply to a (relaxed_dot wants i8 × u7)
|
|
798
833
|
v128_t sign_u8x16 = wasm_v128_and(wasm_v128_xor(a.v128, b.v128), sign_mask_u8x16);
|
|
@@ -1233,13 +1268,13 @@ NK_INTERNAL void nk_dot_u1x128_finalize_v128relaxed(
|
|
|
1233
1268
|
v128_t a_u32x4 = state_a->dot_count_u32x4, b_u32x4 = state_b->dot_count_u32x4;
|
|
1234
1269
|
v128_t c_u32x4 = state_c->dot_count_u32x4, d_u32x4 = state_d->dot_count_u32x4;
|
|
1235
1270
|
// Step 1: interleave pairs
|
|
1236
|
-
v128_t
|
|
1237
|
-
v128_t
|
|
1238
|
-
v128_t
|
|
1239
|
-
v128_t
|
|
1271
|
+
v128_t ab_low_u32x4 = wasm_i32x4_shuffle(a_u32x4, b_u32x4, 0, 4, 1, 5); // a0 b0 a1 b1
|
|
1272
|
+
v128_t ab_high_u32x4 = wasm_i32x4_shuffle(a_u32x4, b_u32x4, 2, 6, 3, 7); // a2 b2 a3 b3
|
|
1273
|
+
v128_t cd_low_u32x4 = wasm_i32x4_shuffle(c_u32x4, d_u32x4, 0, 4, 1, 5); // c0 d0 c1 d1
|
|
1274
|
+
v128_t cd_high_u32x4 = wasm_i32x4_shuffle(c_u32x4, d_u32x4, 2, 6, 3, 7); // c2 d2 c3 d3
|
|
1240
1275
|
// Step 2: pairwise add
|
|
1241
|
-
v128_t sum_02_u32x4 = wasm_i32x4_add(
|
|
1242
|
-
v128_t sum_13_u32x4 = wasm_i32x4_add(
|
|
1276
|
+
v128_t sum_02_u32x4 = wasm_i32x4_add(ab_low_u32x4, ab_high_u32x4); // a02 b02 a13 b13
|
|
1277
|
+
v128_t sum_13_u32x4 = wasm_i32x4_add(cd_low_u32x4, cd_high_u32x4); // c02 d02 c13 d13
|
|
1243
1278
|
// Step 3: final interleave
|
|
1244
1279
|
v128_t even_u32x4 = wasm_i32x4_shuffle(sum_02_u32x4, sum_13_u32x4, 0, 1, 4, 5);
|
|
1245
1280
|
v128_t odd_u32x4 = wasm_i32x4_shuffle(sum_02_u32x4, sum_13_u32x4, 2, 3, 6, 7);
|