numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -0,0 +1,671 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for LoongArch LASX (256-bit).
|
|
3
|
+
* @file include/numkong/dot/loongsonasx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_loongsonasx_instructions Key LASX Dot Product Instructions
|
|
10
|
+
*
|
|
11
|
+
* LASX provides 256-bit SIMD operations using __m256i as the universal vector type.
|
|
12
|
+
* All intrinsics are prefixed with __lasx_. Float operations reinterpret __m256i as
|
|
13
|
+
* f32x8 or f64x4. Integer widening multiply-accumulate chains handle i8/u8 dot products.
|
|
14
|
+
*
|
|
15
|
+
* For F32 dot products, upcasting to F64 and downcasting back is faster than stable
|
|
16
|
+
* summation algorithms. For F64 we use the Dot2 algorithm (Ogita-Rump-Oishi, 2005)
|
|
17
|
+
* for compensated accumulation via TwoSum/TwoProd.
|
|
18
|
+
*
|
|
19
|
+
* @section dot_loongsonasx_stateful Stateful Streaming Logic
|
|
20
|
+
*
|
|
21
|
+
* - nk_dot_f64x4 state with Dot2 stable dot-products,
|
|
22
|
+
* - nk_dot_f32x4 state with double-precision numerics,
|
|
23
|
+
* - nk_dot_through_i32 state for 8-bit signed and unsigned integer inputs.
|
|
24
|
+
*/
|
|
25
|
+
#ifndef NK_DOT_LOONGSONASX_H
|
|
26
|
+
#define NK_DOT_LOONGSONASX_H
|
|
27
|
+
|
|
28
|
+
#if NK_TARGET_LOONGARCH_
|
|
29
|
+
#if NK_TARGET_LOONGSONASX
|
|
30
|
+
|
|
31
|
+
#include "numkong/types.h"
|
|
32
|
+
#include "numkong/dot/serial.h"
|
|
33
|
+
#include "numkong/cast/loongsonasx.h" // `nk_bf16x8_to_f32x8_loongsonasx_`
|
|
34
|
+
|
|
35
|
+
#if defined(__cplusplus)
|
|
36
|
+
extern "C" {
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
#pragma region Horizontal Reduction Helpers
|
|
40
|
+
|
|
41
|
+
/** @brief Horizontal sum of 4 f64 lanes in a 256-bit LASX register.
 *  @param sum_f64x4 Vector accumulator whose 4 lanes are summed.
 *  @return The scalar sum of all 4 lanes, in plain left-to-right tree order.
 */
NK_INTERNAL nk_f64_t nk_reduce_add_f64x4_loongsonasx_(__m256d sum_f64x4) {
    // Add high 128-bit lane to low 128-bit lane (xvpermi_q 0x11 broadcasts the upper half)
    __m256d high_f64x4 = (__m256d)__lasx_xvpermi_q((__m256i)sum_f64x4, (__m256i)sum_f64x4, 0x11);
    __m256d sum_f64x2 = __lasx_xvfadd_d(sum_f64x4, high_f64x4);
    // Swap lanes 0↔1, add to reduce to 1 value, then extract
    __m256d swapped_f64x2 = (__m256d)__lasx_xvshuf4i_d((__m256i)sum_f64x2, (__m256i)sum_f64x2, 0b0001);
    __m256d reduced_f64x2 = __lasx_xvfadd_d(sum_f64x2, swapped_f64x2);
    // Extract lane 0 bit-exactly: go through a u64 ↔ f64 union, since the pick
    // intrinsic returns an integer register, not a float.
    nk_fui64_t c;
    c.u = (nk_u64_t)__lasx_xvpickve2gr_du((__m256i)reduced_f64x2, 0);
    return c.f;
}
|
|
53
|
+
|
|
54
|
+
/** @brief Horizontal sum of 8 i32 lanes in a 256-bit LASX register.
 *  @param sum_i32x8 Vector accumulator whose 8 lanes are summed.
 *  @return The lane sum truncated to i32, matching the caller's modular i32 semantics.
 */
NK_INTERNAL nk_i32_t nk_reduce_add_i32x8_loongsonasx_(__m256i sum_i32x8) {
    // Fold the high 128-bit lane onto the low one: 8 meaningful lanes → 4
    __m256i high_i32x8 = __lasx_xvpermi_q(sum_i32x8, sum_i32x8, 0x11);
    __m256i sum_i32x4 = __lasx_xvadd_w(sum_i32x8, high_i32x8);
    // Pairwise widen i32 → i64, then extract the two i64 of the low lane and add
    __m256i sum_i64x2 = __lasx_xvhaddw_d_w(sum_i32x4, sum_i32x4);
    return (nk_i32_t)(__lasx_xvpickve2gr_d(sum_i64x2, 0) + __lasx_xvpickve2gr_d(sum_i64x2, 1));
}
|
|
62
|
+
|
|
63
|
+
/** @brief Compensated horizontal sum of 4 f64 lanes via TwoSum tree reduction.
 *  @param sum_f64x4 Running vector sum produced by the Dot2 loop.
 *  @param compensation_f64x4 Accumulated per-lane rounding errors from the Dot2 loop.
 *  @return Scalar sum with the accumulated compensation folded back in at the end.
 *  @sa nk_reduce_sum_f64_serial_ for the serial equivalent
 */
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x4_loongsonasx_(__m256d sum_f64x4, __m256d compensation_f64x4) {
    // Stage 0: TwoSum merge of sum + compensation (4-wide, parallel).
    // Knuth's branch-free TwoSum: t = a + b; v = t - a; err = (a - (t - v)) + (b - v).
    __m256d tentative_sum_f64x4 = __lasx_xvfadd_d(sum_f64x4, compensation_f64x4);
    __m256d virtual_addend_f64x4 = __lasx_xvfsub_d(tentative_sum_f64x4, sum_f64x4);
    __m256d rounding_error_f64x4 = __lasx_xvfadd_d(
        __lasx_xvfsub_d(sum_f64x4, __lasx_xvfsub_d(tentative_sum_f64x4, virtual_addend_f64x4)),
        __lasx_xvfsub_d(compensation_f64x4, virtual_addend_f64x4));

    // Stage 1: TwoSum halving 4 → 2 by adding high 128-bit lane to low 128-bit lane
    __m256d upper_sum_f64x4 = (__m256d)__lasx_xvpermi_q((__m256i)tentative_sum_f64x4, (__m256i)tentative_sum_f64x4,
                                                        0x11);
    __m256d lower_sum_f64x4 = tentative_sum_f64x4; // low 128 bits are already there
    __m256d tentative_sum_f64x2 = __lasx_xvfadd_d(lower_sum_f64x4, upper_sum_f64x4);
    __m256d virtual_addend_f64x2 = __lasx_xvfsub_d(tentative_sum_f64x2, lower_sum_f64x4);
    __m256d rounding_error_f64x2 = __lasx_xvfadd_d(
        __lasx_xvfsub_d(lower_sum_f64x4, __lasx_xvfsub_d(tentative_sum_f64x2, virtual_addend_f64x2)),
        __lasx_xvfsub_d(upper_sum_f64x4, virtual_addend_f64x2));
    // Accumulate errors: stage 0 errors (halved) + stage 1 rounding error
    __m256d upper_error_f64x4 = (__m256d)__lasx_xvpermi_q((__m256i)rounding_error_f64x4, (__m256i)rounding_error_f64x4,
                                                          0x11);
    __m256d lower_error_f64x4 = rounding_error_f64x4; // low 128 bits are already there
    __m256d accumulated_error_f64x2 = __lasx_xvfadd_d(__lasx_xvfadd_d(lower_error_f64x4, upper_error_f64x4),
                                                      rounding_error_f64x2);

    // Stage 2: Scalar TwoSum 2 → 1. Lanes are moved out bit-exactly via a u64 ↔ f64 union.
    nk_fui64_t c;
    c.u = (nk_u64_t)__lasx_xvpickve2gr_du((__m256i)tentative_sum_f64x2, 0);
    nk_f64_t sum_low = c.f;
    c.u = (nk_u64_t)__lasx_xvpickve2gr_du((__m256i)tentative_sum_f64x2, 1);
    nk_f64_t sum_high = c.f;
    c.u = (nk_u64_t)__lasx_xvpickve2gr_du((__m256i)accumulated_error_f64x2, 0);
    nk_f64_t error_low = c.f;
    c.u = (nk_u64_t)__lasx_xvpickve2gr_du((__m256i)accumulated_error_f64x2, 1);
    nk_f64_t error_high = c.f;
    nk_f64_t tentative_sum = sum_low + sum_high;
    nk_f64_t virtual_addend = tentative_sum - sum_low;
    nk_f64_t rounding_error = (sum_low - (tentative_sum - virtual_addend)) + (sum_high - virtual_addend);
    return tentative_sum + (error_low + error_high + rounding_error);
}
|
|
105
|
+
|
|
106
|
+
#pragma endregion Horizontal Reduction Helpers
|
|
107
|
+
|
|
108
|
+
#pragma region F32 and F64 Floats
|
|
109
|
+
|
|
110
|
+
NK_PUBLIC void nk_dot_f32_loongsonasx(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
111
|
+
nk_f64_t *result) {
|
|
112
|
+
// LASX is 256-bit = 8 × f32. Load 8 f32, split into low/high 4, widen each to f64, FMA in f64.
|
|
113
|
+
__m256d sum_low_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0); // 4 f64 accumulators (from low 4 f32)
|
|
114
|
+
__m256d sum_high_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0); // 4 f64 accumulators (from high 4 f32)
|
|
115
|
+
nk_size_t idx_scalars = 0;
|
|
116
|
+
for (; idx_scalars + 8 <= count_scalars; idx_scalars += 8) {
|
|
117
|
+
__m256i a_f32x8 = __lasx_xvld(a_scalars + idx_scalars, 0);
|
|
118
|
+
__m256i b_f32x8 = __lasx_xvld(b_scalars + idx_scalars, 0);
|
|
119
|
+
// Widen low 4 f32 → f64
|
|
120
|
+
__m256d a_low_f64x4 = __lasx_xvfcvtl_d_s((__m256)a_f32x8);
|
|
121
|
+
__m256d b_low_f64x4 = __lasx_xvfcvtl_d_s((__m256)b_f32x8);
|
|
122
|
+
// Widen high 4 f32 → f64
|
|
123
|
+
__m256d a_high_f64x4 = __lasx_xvfcvth_d_s((__m256)a_f32x8);
|
|
124
|
+
__m256d b_high_f64x4 = __lasx_xvfcvth_d_s((__m256)b_f32x8);
|
|
125
|
+
// FMA in f64
|
|
126
|
+
sum_low_f64x4 = __lasx_xvfmadd_d(a_low_f64x4, b_low_f64x4, sum_low_f64x4);
|
|
127
|
+
sum_high_f64x4 = __lasx_xvfmadd_d(a_high_f64x4, b_high_f64x4, sum_high_f64x4);
|
|
128
|
+
}
|
|
129
|
+
__m256d combined_f64x4 = __lasx_xvfadd_d(sum_low_f64x4, sum_high_f64x4);
|
|
130
|
+
nk_f64_t sum = nk_reduce_add_f64x4_loongsonasx_(combined_f64x4);
|
|
131
|
+
for (; idx_scalars < count_scalars; ++idx_scalars) sum += (nk_f64_t)a_scalars[idx_scalars] * b_scalars[idx_scalars];
|
|
132
|
+
*result = sum;
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
/** @brief f64 dot product using the compensated Dot2 algorithm (Ogita-Rump-Oishi, 2005).
 *  @param a_scalars First operand, `count_scalars` contiguous f64 values.
 *  @param b_scalars Second operand, same length.
 *  @param count_scalars Number of scalar pairs to multiply-accumulate.
 *  @param result Output location for the compensated f64 dot product.
 */
NK_PUBLIC void nk_dot_f64_loongsonasx(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
                                      nk_f64_t *result) {
    // Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated dot product
    __m256d sum_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    __m256d compensation_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
    nk_size_t idx_scalars = 0;
    for (; idx_scalars + 4 <= count_scalars; idx_scalars += 4) {
        __m256d a_f64x4 = (__m256d)__lasx_xvld(a_scalars + idx_scalars, 0);
        __m256d b_f64x4 = (__m256d)__lasx_xvld(b_scalars + idx_scalars, 0);

        // TwoProd: h = a * b, r = fma(a, b, -h) captures the rounding error
        __m256d product_f64x4 = __lasx_xvfmul_d(a_f64x4, b_f64x4);
        __m256d product_error_f64x4 = __lasx_xvfmsub_d(a_f64x4, b_f64x4, product_f64x4);

        // TwoSum: (t, q) = TwoSum(sum, h) where t = sum + h rounded, q = error
        __m256d tentative_sum_f64x4 = __lasx_xvfadd_d(sum_f64x4, product_f64x4);
        __m256d virtual_addend_f64x4 = __lasx_xvfsub_d(tentative_sum_f64x4, sum_f64x4);
        __m256d sum_error_f64x4 = __lasx_xvfadd_d(
            __lasx_xvfsub_d(sum_f64x4, __lasx_xvfsub_d(tentative_sum_f64x4, virtual_addend_f64x4)),
            __lasx_xvfsub_d(product_f64x4, virtual_addend_f64x4));

        // Update: sum = t, compensation += q + r
        sum_f64x4 = tentative_sum_f64x4;
        compensation_f64x4 = __lasx_xvfadd_d(compensation_f64x4, __lasx_xvfadd_d(sum_error_f64x4, product_error_f64x4));
    }
    // Compensated horizontal reduction, then a plain-FP scalar tail for the remainder
    nk_f64_t sum = nk_dot_stable_sum_f64x4_loongsonasx_(sum_f64x4, compensation_f64x4);
    for (; idx_scalars < count_scalars; ++idx_scalars) sum += (nk_f64_t)a_scalars[idx_scalars] * b_scalars[idx_scalars];
    *result = sum;
}
|
|
165
|
+
|
|
166
|
+
NK_PUBLIC void nk_dot_i8_loongsonasx(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
|
|
167
|
+
nk_i32_t *result) {
|
|
168
|
+
__m256i sum_i32x8 = __lasx_xvreplgr2vr_w(0);
|
|
169
|
+
nk_size_t idx_scalars = 0;
|
|
170
|
+
for (; idx_scalars + 32 <= count_scalars; idx_scalars += 32) {
|
|
171
|
+
__m256i a_i8x32 = __lasx_xvld(a_scalars + idx_scalars, 0);
|
|
172
|
+
__m256i b_i8x32 = __lasx_xvld(b_scalars + idx_scalars, 0);
|
|
173
|
+
// Widening multiply i8 × i8 → i16 (even and odd elements separately)
|
|
174
|
+
__m256i acc_i16x16 = __lasx_xvreplgr2vr_h(0);
|
|
175
|
+
acc_i16x16 = __lasx_xvmaddwev_h_b(acc_i16x16, a_i8x32, b_i8x32);
|
|
176
|
+
acc_i16x16 = __lasx_xvmaddwod_h_b(acc_i16x16, a_i8x32, b_i8x32);
|
|
177
|
+
// Horizontal pairwise i16 → i32, then accumulate
|
|
178
|
+
__m256i widened_i32x8 = __lasx_xvhaddw_w_h(acc_i16x16, acc_i16x16);
|
|
179
|
+
sum_i32x8 = __lasx_xvadd_w(sum_i32x8, widened_i32x8);
|
|
180
|
+
}
|
|
181
|
+
nk_i32_t sum = nk_reduce_add_i32x8_loongsonasx_(sum_i32x8);
|
|
182
|
+
for (; idx_scalars < count_scalars; ++idx_scalars) sum += (nk_i32_t)a_scalars[idx_scalars] * b_scalars[idx_scalars];
|
|
183
|
+
*result = sum;
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
NK_PUBLIC void nk_dot_u8_loongsonasx(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
|
|
187
|
+
nk_u32_t *result) {
|
|
188
|
+
__m256i sum_i32x8 = __lasx_xvreplgr2vr_w(0);
|
|
189
|
+
nk_size_t idx_scalars = 0;
|
|
190
|
+
for (; idx_scalars + 32 <= count_scalars; idx_scalars += 32) {
|
|
191
|
+
__m256i a_u8x32 = __lasx_xvld(a_scalars + idx_scalars, 0);
|
|
192
|
+
__m256i b_u8x32 = __lasx_xvld(b_scalars + idx_scalars, 0);
|
|
193
|
+
// Unsigned widening multiply u8 × u8 → u16 (even and odd elements separately)
|
|
194
|
+
__m256i acc_u16x16 = __lasx_xvreplgr2vr_h(0);
|
|
195
|
+
acc_u16x16 = __lasx_xvmaddwev_h_bu(acc_u16x16, a_u8x32, b_u8x32);
|
|
196
|
+
acc_u16x16 = __lasx_xvmaddwod_h_bu(acc_u16x16, a_u8x32, b_u8x32);
|
|
197
|
+
// Unsigned horizontal pairwise u16 → u32, then accumulate
|
|
198
|
+
__m256i widened_u32x8 = __lasx_xvhaddw_wu_hu(acc_u16x16, acc_u16x16);
|
|
199
|
+
sum_i32x8 = __lasx_xvadd_w(sum_i32x8, widened_u32x8);
|
|
200
|
+
}
|
|
201
|
+
nk_u32_t sum = (nk_u32_t)nk_reduce_add_i32x8_loongsonasx_(sum_i32x8);
|
|
202
|
+
for (; idx_scalars < count_scalars; ++idx_scalars) sum += (nk_u32_t)a_scalars[idx_scalars] * b_scalars[idx_scalars];
|
|
203
|
+
*result = sum;
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
/** @brief bf16 dot product on LASX, accumulated in f32.
 *  @param a_scalars First operand, `count_scalars` contiguous bf16 values.
 *  @param b_scalars Second operand, same length.
 *  @param count_scalars Number of scalar pairs to multiply-accumulate.
 *  @param result Output location for the f32 dot product.
 */
NK_PUBLIC void nk_dot_bf16_loongsonasx(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
                                       nk_f32_t *result) {
    __m256 sum_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
    __m256i mask_high_u32x8 = __lasx_xvreplgr2vr_w((int)0xFFFF0000);
    nk_size_t idx_scalars = 0;
    for (; idx_scalars + 16 <= count_scalars; idx_scalars += 16) {
        __m256i a_bf16x16 = __lasx_xvld(a_scalars + idx_scalars, 0);
        __m256i b_bf16x16 = __lasx_xvld(b_scalars + idx_scalars, 0);
        // Even-indexed bf16 values: a left shift by 16 turns each into its f32 equivalent
        __m256 a_even_f32x8 = (__m256)__lasx_xvslli_w(a_bf16x16, 16);
        __m256 b_even_f32x8 = (__m256)__lasx_xvslli_w(b_bf16x16, 16);
        sum_f32x8 = __lasx_xvfmadd_s(a_even_f32x8, b_even_f32x8, sum_f32x8);
        // Odd-indexed bf16 values already occupy the high 16 bits of each 32-bit slot;
        // masking off the low half yields the f32 equivalent directly
        __m256 a_odd_f32x8 = (__m256)__lasx_xvand_v(a_bf16x16, mask_high_u32x8);
        __m256 b_odd_f32x8 = (__m256)__lasx_xvand_v(b_bf16x16, mask_high_u32x8);
        sum_f32x8 = __lasx_xvfmadd_s(a_odd_f32x8, b_odd_f32x8, sum_f32x8);
    }
    // Horizontal reduce 8 × f32 → 1 × f32: fold high 128-bit lane, then two in-lane swaps
    __m256 high_f32x4 = (__m256)__lasx_xvpermi_q((__m256i)sum_f32x8, (__m256i)sum_f32x8, 0x11);
    __m256 sum_f32x4 = __lasx_xvfadd_s(sum_f32x8, high_f32x4);
    __m256 swapped_f32x4 = (__m256)__lasx_xvshuf4i_w((__m256i)sum_f32x4, 0b01001110);
    __m256 reduced_f32x4 = __lasx_xvfadd_s(sum_f32x4, swapped_f32x4);
    __m256 swapped_f32x2 = (__m256)__lasx_xvshuf4i_w((__m256i)reduced_f32x4, 0b10110001);
    __m256 reduced_f32x2 = __lasx_xvfadd_s(reduced_f32x4, swapped_f32x2);
    // Bit-exact extraction of lane 0 through a u32 ↔ f32 union
    nk_fui32_t c;
    c.u = (nk_u32_t)__lasx_xvpickve2gr_w((__m256i)reduced_f32x2, 0);
    nk_f32_t sum = c.f;
    // Scalar tail: upcast each remaining bf16 pair via the serial converter
    for (; idx_scalars < count_scalars; ++idx_scalars) {
        nk_f32_t a_val, b_val;
        nk_bf16_to_f32_serial(&a_scalars[idx_scalars], &a_val);
        nk_bf16_to_f32_serial(&b_scalars[idx_scalars], &b_val);
        sum += a_val * b_val;
    }
    *result = sum;
}
|
|
239
|
+
|
|
240
|
+
/** @brief Streaming Dot2 state for f64 inputs: raw-bit (__m256i) storage of two f64x4
 *  vectors, reinterpreted as __m256d at use sites. */
typedef struct nk_dot_f64x4_state_loongsonasx_t {
    __m256i sum_f64x4;          // Running Dot2 vector sum (f64x4 stored as raw bits)
    __m256i compensation_f64x4; // Error accumulator for Dot2
} nk_dot_f64x4_state_loongsonasx_t;
|
|
244
|
+
|
|
245
|
+
/** @brief Resets a streaming f64 Dot2 state: zero running sum, zero compensation. */
NK_INTERNAL void nk_dot_f64x4_init_loongsonasx(nk_dot_f64x4_state_loongsonasx_t *state) {
    __m256i zeros_i64x4 = __lasx_xvreplgr2vr_d(0);
    state->sum_f64x4 = zeros_i64x4;
    state->compensation_f64x4 = zeros_i64x4;
}
|
|
249
|
+
|
|
250
|
+
/** @brief Streaming Dot2 step: folds one 4 × f64 slice of each operand into the state.
 *  @param state Running sum + compensation, updated in place.
 *  @param a First operand slice (4 f64 values in a 256-bit vector).
 *  @param b Second operand slice.
 *  @param depth_offset Unused here; part of the shared streaming-update signature.
 *  @param active_dimensions Unused here; callers are expected to pre-mask partial slices.
 */
NK_INTERNAL void nk_dot_f64x4_update_loongsonasx(nk_dot_f64x4_state_loongsonasx_t *state, nk_b256_vec_t a,
                                                 nk_b256_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    __m256d sum_f64x4 = (__m256d)state->sum_f64x4;
    __m256d compensation_f64x4 = (__m256d)state->compensation_f64x4;
    __m256d a_f64x4 = a.ymm_pd;
    __m256d b_f64x4 = b.ymm_pd;

    // TwoProd: h = a * b, r = fma(a, b, -h) captures the rounding error
    __m256d product_f64x4 = __lasx_xvfmul_d(a_f64x4, b_f64x4);
    __m256d product_error_f64x4 = __lasx_xvfmsub_d(a_f64x4, b_f64x4, product_f64x4);

    // TwoSum: (t, q) = TwoSum(sum, h) where t = sum + h rounded, q = error
    __m256d tentative_sum_f64x4 = __lasx_xvfadd_d(sum_f64x4, product_f64x4);
    __m256d virtual_addend_f64x4 = __lasx_xvfsub_d(tentative_sum_f64x4, sum_f64x4);
    __m256d sum_error_f64x4 = __lasx_xvfadd_d(
        __lasx_xvfsub_d(sum_f64x4, __lasx_xvfsub_d(tentative_sum_f64x4, virtual_addend_f64x4)),
        __lasx_xvfsub_d(product_f64x4, virtual_addend_f64x4));

    // Update: sum = t, compensation += q + r
    state->sum_f64x4 = (__m256i)tentative_sum_f64x4;
    state->compensation_f64x4 = (__m256i)__lasx_xvfadd_d(compensation_f64x4,
                                                         __lasx_xvfadd_d(sum_error_f64x4, product_error_f64x4));
}
|
|
275
|
+
|
|
276
|
+
NK_INTERNAL void nk_dot_f64x4_finalize_loongsonasx( //
|
|
277
|
+
nk_dot_f64x4_state_loongsonasx_t const *state_a, nk_dot_f64x4_state_loongsonasx_t const *state_b, //
|
|
278
|
+
nk_dot_f64x4_state_loongsonasx_t const *state_c, nk_dot_f64x4_state_loongsonasx_t const *state_d, //
|
|
279
|
+
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
280
|
+
nk_unused_(total_dimensions);
|
|
281
|
+
// Compensated horizontal reduction preserving Dot2 error tracking per state
|
|
282
|
+
result->f64s[0] = nk_dot_stable_sum_f64x4_loongsonasx_((__m256d)state_a->sum_f64x4,
|
|
283
|
+
(__m256d)state_a->compensation_f64x4);
|
|
284
|
+
result->f64s[1] = nk_dot_stable_sum_f64x4_loongsonasx_((__m256d)state_b->sum_f64x4,
|
|
285
|
+
(__m256d)state_b->compensation_f64x4);
|
|
286
|
+
result->f64s[2] = nk_dot_stable_sum_f64x4_loongsonasx_((__m256d)state_c->sum_f64x4,
|
|
287
|
+
(__m256d)state_c->compensation_f64x4);
|
|
288
|
+
result->f64s[3] = nk_dot_stable_sum_f64x4_loongsonasx_((__m256d)state_d->sum_f64x4,
|
|
289
|
+
(__m256d)state_d->compensation_f64x4);
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
typedef struct nk_dot_f32x8_state_loongsonasx_t {
|
|
293
|
+
__m256i sum_f64x4;
|
|
294
|
+
} nk_dot_f32x8_state_loongsonasx_t;
|
|
295
|
+
|
|
296
|
+
NK_INTERNAL void nk_dot_f32x8_init_loongsonasx(nk_dot_f32x8_state_loongsonasx_t *state) {
|
|
297
|
+
state->sum_f64x4 = __lasx_xvreplgr2vr_d(0);
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
NK_INTERNAL void nk_dot_f32x8_update_loongsonasx(nk_dot_f32x8_state_loongsonasx_t *state, nk_b256_vec_t a,
                                                 nk_b256_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // Upcast each f32 half to f64 and fuse multiply-add into the f64 accumulator,
    // so single-precision inputs accumulate without intermediate rounding.
    __m256d sum = (__m256d)state->sum_f64x4;
    sum = __lasx_xvfmadd_d(__lasx_xvfcvtl_d_s(a.ymm_ps), __lasx_xvfcvtl_d_s(b.ymm_ps), sum);
    sum = __lasx_xvfmadd_d(__lasx_xvfcvth_d_s(a.ymm_ps), __lasx_xvfcvth_d_s(b.ymm_ps), sum);
    state->sum_f64x4 = (__m256i)sum;
}
|
|
311
|
+
|
|
312
|
+
NK_INTERNAL void nk_dot_f32x8_finalize_loongsonasx( //
|
|
313
|
+
nk_dot_f32x8_state_loongsonasx_t const *state_a, nk_dot_f32x8_state_loongsonasx_t const *state_b, //
|
|
314
|
+
nk_dot_f32x8_state_loongsonasx_t const *state_c, nk_dot_f32x8_state_loongsonasx_t const *state_d, //
|
|
315
|
+
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
316
|
+
nk_unused_(total_dimensions);
|
|
317
|
+
// Horizontal reduction: 4 f64s → 1 f64 for each state, packed into result via SIMD
|
|
318
|
+
__m256d sum_a_f64x4 = (__m256d)state_a->sum_f64x4;
|
|
319
|
+
__m256d sum_b_f64x4 = (__m256d)state_b->sum_f64x4;
|
|
320
|
+
__m256d sum_c_f64x4 = (__m256d)state_c->sum_f64x4;
|
|
321
|
+
__m256d sum_d_f64x4 = (__m256d)state_d->sum_f64x4;
|
|
322
|
+
|
|
323
|
+
// 4 → 2: add high 128-bit lane to low lane
|
|
324
|
+
__m256d sum_a_f64x2 = __lasx_xvfadd_d(sum_a_f64x4,
|
|
325
|
+
(__m256d)__lasx_xvpermi_q((__m256i)sum_a_f64x4, (__m256i)sum_a_f64x4, 0x11));
|
|
326
|
+
__m256d sum_b_f64x2 = __lasx_xvfadd_d(sum_b_f64x4,
|
|
327
|
+
(__m256d)__lasx_xvpermi_q((__m256i)sum_b_f64x4, (__m256i)sum_b_f64x4, 0x11));
|
|
328
|
+
__m256d sum_c_f64x2 = __lasx_xvfadd_d(sum_c_f64x4,
|
|
329
|
+
(__m256d)__lasx_xvpermi_q((__m256i)sum_c_f64x4, (__m256i)sum_c_f64x4, 0x11));
|
|
330
|
+
__m256d sum_d_f64x2 = __lasx_xvfadd_d(sum_d_f64x4,
|
|
331
|
+
(__m256d)__lasx_xvpermi_q((__m256i)sum_d_f64x4, (__m256i)sum_d_f64x4, 0x11));
|
|
332
|
+
|
|
333
|
+
// 2 → 1: interleave then horizontal add (xvilvl_d/xvilvh_d are integer intrinsics)
|
|
334
|
+
__m256d ab_low_f64x2 = (__m256d)__lasx_xvilvl_d((__m256i)sum_b_f64x2, (__m256i)sum_a_f64x2);
|
|
335
|
+
__m256d ab_high_f64x2 = (__m256d)__lasx_xvilvh_d((__m256i)sum_b_f64x2, (__m256i)sum_a_f64x2);
|
|
336
|
+
__m256d cd_low_f64x2 = (__m256d)__lasx_xvilvl_d((__m256i)sum_d_f64x2, (__m256i)sum_c_f64x2);
|
|
337
|
+
__m256d cd_high_f64x2 = (__m256d)__lasx_xvilvh_d((__m256i)sum_d_f64x2, (__m256i)sum_c_f64x2);
|
|
338
|
+
__m256d sum_ab_f64x2 = __lasx_xvfadd_d(ab_low_f64x2, ab_high_f64x2);
|
|
339
|
+
__m256d sum_cd_f64x2 = __lasx_xvfadd_d(cd_low_f64x2, cd_high_f64x2);
|
|
340
|
+
|
|
341
|
+
// Pack [sum_a, sum_b, sum_c, sum_d] into one 256-bit result
|
|
342
|
+
result->ymm = __lasx_xvpermi_q((__m256i)sum_cd_f64x2, (__m256i)sum_ab_f64x2, 0x20);
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
#pragma endregion F32 and F64 Floats
|
|
346
|
+
|
|
347
|
+
#pragma region I8 and U8 Integers
|
|
348
|
+
|
|
349
|
+
/**
|
|
350
|
+
* @brief Internal helper state for dot-products of integer types, where 32-bit accumulation is enough.
|
|
351
|
+
* @sa nk_dot_i8x16_state_loongsonasx_t, nk_dot_u8x16_state_loongsonasx_t
|
|
352
|
+
*/
|
|
353
|
+
typedef struct nk_dot_through_i32_state_loongsonasx_t_ {
|
|
354
|
+
__m256i sum_i32x8;
|
|
355
|
+
} nk_dot_through_i32_state_loongsonasx_t_;
|
|
356
|
+
|
|
357
|
+
/**
|
|
358
|
+
* @brief Initializes 32-bit accumulators for integer dot-products.
|
|
359
|
+
* @sa nk_dot_i8x16_update_loongsonasx, nk_dot_u8x16_update_loongsonasx
|
|
360
|
+
*/
|
|
361
|
+
NK_INTERNAL void nk_dot_through_i32_init_loongsonasx_(nk_dot_through_i32_state_loongsonasx_t_ *state) {
|
|
362
|
+
state->sum_i32x8 = __lasx_xvreplgr2vr_w(0);
|
|
363
|
+
}
|
|
364
|
+
|
|
365
|
+
/**
|
|
366
|
+
* @brief Finalizes 4x integer dot-products placing them into 4x consecutive 32-bit slots.
|
|
367
|
+
* @sa nk_dot_i8x16_update_loongsonasx, nk_dot_u8x16_update_loongsonasx
|
|
368
|
+
*/
|
|
369
|
+
NK_INTERNAL void nk_dot_through_i32_finalize_loongsonasx_( //
|
|
370
|
+
nk_dot_through_i32_state_loongsonasx_t_ const *state_a, nk_dot_through_i32_state_loongsonasx_t_ const *state_b, //
|
|
371
|
+
nk_dot_through_i32_state_loongsonasx_t_ const *state_c, nk_dot_through_i32_state_loongsonasx_t_ const *state_d, //
|
|
372
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
373
|
+
nk_unused_(total_dimensions);
|
|
374
|
+
// ILP-optimized 4-way horizontal reduction for i32 using LASX interleave
|
|
375
|
+
// Step 1: 8 → 4 for all 4 states (add high 128-bit lane to low lane)
|
|
376
|
+
__m256i sum_a_i32x4 = __lasx_xvadd_w(state_a->sum_i32x8,
|
|
377
|
+
__lasx_xvpermi_q(state_a->sum_i32x8, state_a->sum_i32x8, 0x11));
|
|
378
|
+
__m256i sum_b_i32x4 = __lasx_xvadd_w(state_b->sum_i32x8,
|
|
379
|
+
__lasx_xvpermi_q(state_b->sum_i32x8, state_b->sum_i32x8, 0x11));
|
|
380
|
+
__m256i sum_c_i32x4 = __lasx_xvadd_w(state_c->sum_i32x8,
|
|
381
|
+
__lasx_xvpermi_q(state_c->sum_i32x8, state_c->sum_i32x8, 0x11));
|
|
382
|
+
__m256i sum_d_i32x4 = __lasx_xvadd_w(state_d->sum_i32x8,
|
|
383
|
+
__lasx_xvpermi_q(state_d->sum_i32x8, state_d->sum_i32x8, 0x11));
|
|
384
|
+
// Step 2: Transpose 4×4 matrix via interleave
|
|
385
|
+
__m256i transpose_ab_low_i32x4 = __lasx_xvilvl_w(sum_b_i32x4, sum_a_i32x4);
|
|
386
|
+
__m256i transpose_cd_low_i32x4 = __lasx_xvilvl_w(sum_d_i32x4, sum_c_i32x4);
|
|
387
|
+
__m256i transpose_ab_high_i32x4 = __lasx_xvilvh_w(sum_b_i32x4, sum_a_i32x4);
|
|
388
|
+
__m256i transpose_cd_high_i32x4 = __lasx_xvilvh_w(sum_d_i32x4, sum_c_i32x4);
|
|
389
|
+
__m256i sum_lane0_i32x4 = __lasx_xvilvl_d(transpose_cd_low_i32x4, transpose_ab_low_i32x4);
|
|
390
|
+
__m256i sum_lane1_i32x4 = __lasx_xvilvh_d(transpose_cd_low_i32x4, transpose_ab_low_i32x4);
|
|
391
|
+
__m256i sum_lane2_i32x4 = __lasx_xvilvl_d(transpose_cd_high_i32x4, transpose_ab_high_i32x4);
|
|
392
|
+
__m256i sum_lane3_i32x4 = __lasx_xvilvh_d(transpose_cd_high_i32x4, transpose_ab_high_i32x4);
|
|
393
|
+
// Step 3: Vertical sum
|
|
394
|
+
__m256i sum_i32x4 = __lasx_xvadd_w(__lasx_xvadd_w(sum_lane0_i32x4, sum_lane1_i32x4),
|
|
395
|
+
__lasx_xvadd_w(sum_lane2_i32x4, sum_lane3_i32x4));
|
|
396
|
+
result->xmm = nk_lasx_castsi256_si128_(sum_i32x4);
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
/**
|
|
400
|
+
* @brief Running state for 128-bit dot accumulation over i8 scalars on LASX.
|
|
401
|
+
* @note Alias of nk_dot_through_i32_state_loongsonasx_t_
|
|
402
|
+
*/
|
|
403
|
+
typedef struct nk_dot_through_i32_state_loongsonasx_t_ nk_dot_i8x32_state_loongsonasx_t;
|
|
404
|
+
|
|
405
|
+
NK_INTERNAL void nk_dot_i8x32_init_loongsonasx(nk_dot_i8x32_state_loongsonasx_t *state) {
|
|
406
|
+
nk_dot_through_i32_init_loongsonasx_(state);
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
NK_INTERNAL void nk_dot_i8x32_update_loongsonasx(nk_dot_i8x32_state_loongsonasx_t *state, nk_b256_vec_t a,
                                                 nk_b256_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // Keep even and odd byte products in separate i16 vectors. The previous
    // `xvmaddwev_h_b` + `xvmaddwod_h_b` into one i16 accumulator could overflow:
    // (-128 * -128) + (-128 * -128) = 32768 > INT16_MAX. Each i16 product itself
    // fits (|p| <= 16384), and `xvhaddw_w_h` performs the pairwise addition at
    // 32-bit width, so this formulation is overflow-free for any i8 inputs.
    __m256i even_products_i16x16 = __lasx_xvmulwev_h_b(a.ymm, b.ymm);
    __m256i odd_products_i16x16 = __lasx_xvmulwod_h_b(a.ymm, b.ymm);
    state->sum_i32x8 = __lasx_xvadd_w(state->sum_i32x8,
                                      __lasx_xvhaddw_w_h(even_products_i16x16, even_products_i16x16));
    state->sum_i32x8 = __lasx_xvadd_w(state->sum_i32x8,
                                      __lasx_xvhaddw_w_h(odd_products_i16x16, odd_products_i16x16));
}
|
|
419
|
+
|
|
420
|
+
NK_INTERNAL void nk_dot_i8x32_finalize_loongsonasx( //
|
|
421
|
+
nk_dot_i8x32_state_loongsonasx_t const *state_a, nk_dot_i8x32_state_loongsonasx_t const *state_b, //
|
|
422
|
+
nk_dot_i8x32_state_loongsonasx_t const *state_c, nk_dot_i8x32_state_loongsonasx_t const *state_d, //
|
|
423
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
424
|
+
nk_dot_through_i32_finalize_loongsonasx_(state_a, state_b, state_c, state_d, total_dimensions, result);
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
/**
|
|
428
|
+
* @brief Running state for 256-bit dot accumulation over u8 scalars on LASX.
|
|
429
|
+
* @note Alias of nk_dot_through_i32_state_loongsonasx_t_
|
|
430
|
+
*/
|
|
431
|
+
typedef struct nk_dot_through_i32_state_loongsonasx_t_ nk_dot_u8x32_state_loongsonasx_t;
|
|
432
|
+
|
|
433
|
+
NK_INTERNAL void nk_dot_u8x32_init_loongsonasx(nk_dot_u8x32_state_loongsonasx_t *state) {
|
|
434
|
+
nk_dot_through_i32_init_loongsonasx_(state);
|
|
435
|
+
}
|
|
436
|
+
|
|
437
|
+
NK_INTERNAL void nk_dot_u8x32_update_loongsonasx(nk_dot_u8x32_state_loongsonasx_t *state, nk_b256_vec_t a,
                                                 nk_b256_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // Keep even and odd byte products in separate u16 vectors. The previous
    // `xvmaddwev_h_bu` + `xvmaddwod_h_bu` into one u16 accumulator overflows:
    // (255 * 255) + (255 * 255) = 130050 > UINT16_MAX, silently wrapping mod 2^16.
    // A single u16 product fits (<= 65025), and `xvhaddw_wu_hu` adds each pair at
    // 32-bit width, so this formulation is exact for any u8 inputs.
    __m256i even_products_u16x16 = __lasx_xvmulwev_h_bu(a.ymm, b.ymm);
    __m256i odd_products_u16x16 = __lasx_xvmulwod_h_bu(a.ymm, b.ymm);
    state->sum_i32x8 = __lasx_xvadd_w(state->sum_i32x8,
                                      __lasx_xvhaddw_wu_hu(even_products_u16x16, even_products_u16x16));
    state->sum_i32x8 = __lasx_xvadd_w(state->sum_i32x8,
                                      __lasx_xvhaddw_wu_hu(odd_products_u16x16, odd_products_u16x16));
}
|
|
447
|
+
|
|
448
|
+
NK_INTERNAL void nk_dot_u8x32_finalize_loongsonasx( //
|
|
449
|
+
nk_dot_u8x32_state_loongsonasx_t const *state_a, nk_dot_u8x32_state_loongsonasx_t const *state_b, //
|
|
450
|
+
nk_dot_u8x32_state_loongsonasx_t const *state_c, nk_dot_u8x32_state_loongsonasx_t const *state_d, //
|
|
451
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
452
|
+
nk_dot_through_i32_finalize_loongsonasx_(state_a, state_b, state_c, state_d, total_dimensions, result);
|
|
453
|
+
}
|
|
454
|
+
|
|
455
|
+
#pragma endregion I8 and U8 Integers
|
|
456
|
+
|
|
457
|
+
#pragma region F16 and BF16 Floats
|
|
458
|
+
|
|
459
|
+
/**
|
|
460
|
+
* @brief Internal helper state for dot-products of low-precision types, where 32-bit accumulation is enough.
|
|
461
|
+
* @sa nk_dot_bf16x16_state_loongsonasx_t
|
|
462
|
+
*/
|
|
463
|
+
typedef struct nk_dot_through_f32_state_loongsonasx_t_ {
|
|
464
|
+
__m256i sum_f32x8;
|
|
465
|
+
} nk_dot_through_f32_state_loongsonasx_t_;
|
|
466
|
+
|
|
467
|
+
/**
|
|
468
|
+
* @brief Initializes 32-bit accumulators for low-precision dot-products.
|
|
469
|
+
* @sa nk_dot_bf16x16_init_loongsonasx
|
|
470
|
+
*/
|
|
471
|
+
NK_INTERNAL void nk_dot_through_f32_init_loongsonasx_(nk_dot_through_f32_state_loongsonasx_t_ *state) {
|
|
472
|
+
state->sum_f32x8 = __lasx_xvreplgr2vr_w(0);
|
|
473
|
+
}
|
|
474
|
+
|
|
475
|
+
/**
|
|
476
|
+
* @brief Fuses 32-bit multiplication and accumulation for pre-converted f32 vectors.
|
|
477
|
+
* @sa nk_dot_bf16x8_update_loongsonasx
|
|
478
|
+
*/
|
|
479
|
+
NK_INTERNAL void nk_dot_through_f32_update_loongsonasx_(nk_dot_through_f32_state_loongsonasx_t_ *state, nk_b256_vec_t a,
|
|
480
|
+
nk_b256_vec_t b, nk_size_t depth_offset,
|
|
481
|
+
nk_size_t active_dimensions) {
|
|
482
|
+
nk_unused_(depth_offset);
|
|
483
|
+
nk_unused_(active_dimensions);
|
|
484
|
+
state->sum_f32x8 = (__m256i)__lasx_xvfmadd_s(a.ymm_ps, b.ymm_ps, (__m256)state->sum_f32x8);
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
/**
|
|
488
|
+
* @brief Finalizes 4x low-precision dot-products placing them into 4x consecutive 32-bit slots.
|
|
489
|
+
* @sa nk_dot_bf16x8_finalize_loongsonasx
|
|
490
|
+
*
|
|
491
|
+
* Computes 4x horizontal reductions, each involving 8x floats, using LASX interleave instructions.
|
|
492
|
+
*/
|
|
493
|
+
NK_INTERNAL void nk_dot_through_f32_finalize_loongsonasx_( //
|
|
494
|
+
nk_dot_through_f32_state_loongsonasx_t_ const *state_a, nk_dot_through_f32_state_loongsonasx_t_ const *state_b, //
|
|
495
|
+
nk_dot_through_f32_state_loongsonasx_t_ const *state_c, nk_dot_through_f32_state_loongsonasx_t_ const *state_d, //
|
|
496
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
497
|
+
nk_unused_(total_dimensions);
|
|
498
|
+
// Step 1: 8 → 4 for all 4 states (add high 128-bit lane to low lane)
|
|
499
|
+
__m256 sum_a_f32x4 = __lasx_xvfadd_s((__m256)state_a->sum_f32x8,
|
|
500
|
+
(__m256)__lasx_xvpermi_q(state_a->sum_f32x8, state_a->sum_f32x8, 0x11));
|
|
501
|
+
__m256 sum_b_f32x4 = __lasx_xvfadd_s((__m256)state_b->sum_f32x8,
|
|
502
|
+
(__m256)__lasx_xvpermi_q(state_b->sum_f32x8, state_b->sum_f32x8, 0x11));
|
|
503
|
+
__m256 sum_c_f32x4 = __lasx_xvfadd_s((__m256)state_c->sum_f32x8,
|
|
504
|
+
(__m256)__lasx_xvpermi_q(state_c->sum_f32x8, state_c->sum_f32x8, 0x11));
|
|
505
|
+
__m256 sum_d_f32x4 = __lasx_xvfadd_s((__m256)state_d->sum_f32x8,
|
|
506
|
+
(__m256)__lasx_xvpermi_q(state_d->sum_f32x8, state_d->sum_f32x8, 0x11));
|
|
507
|
+
// Step 2: Transpose 4×4 matrix via interleave (integer intrinsics, cast at boundaries)
|
|
508
|
+
__m256i transpose_ab_low_f32x4 = __lasx_xvilvl_w((__m256i)sum_b_f32x4, (__m256i)sum_a_f32x4);
|
|
509
|
+
__m256i transpose_cd_low_f32x4 = __lasx_xvilvl_w((__m256i)sum_d_f32x4, (__m256i)sum_c_f32x4);
|
|
510
|
+
__m256i transpose_ab_high_f32x4 = __lasx_xvilvh_w((__m256i)sum_b_f32x4, (__m256i)sum_a_f32x4);
|
|
511
|
+
__m256i transpose_cd_high_f32x4 = __lasx_xvilvh_w((__m256i)sum_d_f32x4, (__m256i)sum_c_f32x4);
|
|
512
|
+
__m256i sum_lane0_f32x4 = __lasx_xvilvl_d(transpose_cd_low_f32x4, transpose_ab_low_f32x4);
|
|
513
|
+
__m256i sum_lane1_f32x4 = __lasx_xvilvh_d(transpose_cd_low_f32x4, transpose_ab_low_f32x4);
|
|
514
|
+
__m256i sum_lane2_f32x4 = __lasx_xvilvl_d(transpose_cd_high_f32x4, transpose_ab_high_f32x4);
|
|
515
|
+
__m256i sum_lane3_f32x4 = __lasx_xvilvh_d(transpose_cd_high_f32x4, transpose_ab_high_f32x4);
|
|
516
|
+
// Step 3: Vertical sum
|
|
517
|
+
__m256 sum_f32x4 = __lasx_xvfadd_s(__lasx_xvfadd_s((__m256)sum_lane0_f32x4, (__m256)sum_lane1_f32x4),
|
|
518
|
+
__lasx_xvfadd_s((__m256)sum_lane2_f32x4, (__m256)sum_lane3_f32x4));
|
|
519
|
+
result->xmm_ps = nk_lasx_castps256_ps128_(sum_f32x4);
|
|
520
|
+
}
|
|
521
|
+
|
|
522
|
+
/**
|
|
523
|
+
* @brief Running state for 128-bit dot accumulation over bf16 scalars on LASX.
|
|
524
|
+
* @note Alias of nk_dot_through_f32_state_loongsonasx_t_
|
|
525
|
+
*/
|
|
526
|
+
typedef struct nk_dot_through_f32_state_loongsonasx_t_ nk_dot_bf16x16_state_loongsonasx_t;
|
|
527
|
+
|
|
528
|
+
NK_INTERNAL void nk_dot_bf16x16_init_loongsonasx(nk_dot_bf16x16_state_loongsonasx_t *state) {
|
|
529
|
+
nk_dot_through_f32_init_loongsonasx_(state);
|
|
530
|
+
}
|
|
531
|
+
|
|
532
|
+
NK_INTERNAL void nk_dot_bf16x16_update_loongsonasx(nk_dot_bf16x16_state_loongsonasx_t *state, nk_b256_vec_t a,
                                                   nk_b256_vec_t b, nk_size_t depth_offset,
                                                   nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // A bf16 value is exactly the upper half of an f32. Even-positioned elements
    // are shifted left by 16 into the f32 high bits; odd-positioned ones are
    // already there and just need the low 16 bits masked off.
    __m256i upper_half_mask = __lasx_xvreplgr2vr_w((int)0xFFFF0000);
    __m256 sum = (__m256)state->sum_f32x8;
    sum = __lasx_xvfmadd_s((__m256)__lasx_xvslli_w(a.ymm, 16), (__m256)__lasx_xvslli_w(b.ymm, 16), sum);
    sum = __lasx_xvfmadd_s((__m256)__lasx_xvand_v(a.ymm, upper_half_mask),
                           (__m256)__lasx_xvand_v(b.ymm, upper_half_mask), sum);
    state->sum_f32x8 = (__m256i)sum;
}
|
|
547
|
+
|
|
548
|
+
NK_INTERNAL void nk_dot_bf16x16_finalize_loongsonasx( //
|
|
549
|
+
nk_dot_bf16x16_state_loongsonasx_t const *state_a, nk_dot_bf16x16_state_loongsonasx_t const *state_b, //
|
|
550
|
+
nk_dot_bf16x16_state_loongsonasx_t const *state_c, nk_dot_bf16x16_state_loongsonasx_t const *state_d, //
|
|
551
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
552
|
+
nk_dot_through_f32_finalize_loongsonasx_(state_a, state_b, state_c, state_d, total_dimensions, result);
|
|
553
|
+
}
|
|
554
|
+
|
|
555
|
+
NK_PUBLIC void nk_dot_f16_loongsonasx(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
556
|
+
nk_f32_t *result) {
|
|
557
|
+
__m256 sum_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
|
|
558
|
+
nk_size_t idx_scalars = 0;
|
|
559
|
+
for (; idx_scalars + 8 <= count_scalars; idx_scalars += 8) {
|
|
560
|
+
__m128i a_f16x8 = __lsx_vld(a_scalars + idx_scalars, 0);
|
|
561
|
+
__m128i b_f16x8 = __lsx_vld(b_scalars + idx_scalars, 0);
|
|
562
|
+
__m256 a_f32x8 = (__m256)nk_f16x8_to_f32x8_loongsonasx_(a_f16x8);
|
|
563
|
+
__m256 b_f32x8 = (__m256)nk_f16x8_to_f32x8_loongsonasx_(b_f16x8);
|
|
564
|
+
sum_f32x8 = __lasx_xvfmadd_s(a_f32x8, b_f32x8, sum_f32x8);
|
|
565
|
+
}
|
|
566
|
+
__m256 high_f32x4 = (__m256)__lasx_xvpermi_q((__m256i)sum_f32x8, (__m256i)sum_f32x8, 0x11);
|
|
567
|
+
__m256 sum_f32x4 = __lasx_xvfadd_s(sum_f32x8, high_f32x4);
|
|
568
|
+
__m256 swapped_f32x4 = (__m256)__lasx_xvshuf4i_w((__m256i)sum_f32x4, 0b01001110);
|
|
569
|
+
__m256 reduced_f32x4 = __lasx_xvfadd_s(sum_f32x4, swapped_f32x4);
|
|
570
|
+
__m256 swapped_f32x2 = (__m256)__lasx_xvshuf4i_w((__m256i)reduced_f32x4, 0b10110001);
|
|
571
|
+
__m256 reduced_f32x2 = __lasx_xvfadd_s(reduced_f32x4, swapped_f32x2);
|
|
572
|
+
nk_fui32_t c;
|
|
573
|
+
c.u = (nk_u32_t)__lasx_xvpickve2gr_w((__m256i)reduced_f32x2, 0);
|
|
574
|
+
nk_f32_t sum = c.f;
|
|
575
|
+
for (; idx_scalars < count_scalars; ++idx_scalars) {
|
|
576
|
+
nk_f32_t a_val, b_val;
|
|
577
|
+
nk_f16_to_f32_serial(&a_scalars[idx_scalars], &a_val);
|
|
578
|
+
nk_f16_to_f32_serial(&b_scalars[idx_scalars], &b_val);
|
|
579
|
+
sum += a_val * b_val;
|
|
580
|
+
}
|
|
581
|
+
*result = sum;
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
/**
|
|
585
|
+
* @brief Running state for 128-bit dot accumulation over f16 scalars on LASX.
|
|
586
|
+
* @note Alias of nk_dot_through_f32_state_loongsonasx_t_
|
|
587
|
+
*/
|
|
588
|
+
typedef struct nk_dot_through_f32_state_loongsonasx_t_ nk_dot_f16x16_state_loongsonasx_t;
|
|
589
|
+
|
|
590
|
+
NK_INTERNAL void nk_dot_f16x16_init_loongsonasx(nk_dot_f16x16_state_loongsonasx_t *state) {
|
|
591
|
+
nk_dot_through_f32_init_loongsonasx_(state);
|
|
592
|
+
}
|
|
593
|
+
|
|
594
|
+
NK_INTERNAL void nk_dot_f16x16_update_loongsonasx(nk_dot_f16x16_state_loongsonasx_t *state, nk_b256_vec_t a,
                                                  nk_b256_vec_t b, nk_size_t depth_offset,
                                                  nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // Widen each f16 half to f32 and fuse multiply-add into the f32 accumulator:
    // low 8 elements first, then the high 8.
    __m256 sum = (__m256)state->sum_f32x8;
    sum = __lasx_xvfmadd_s(__lasx_xvfcvtl_s_h(a.ymm), __lasx_xvfcvtl_s_h(b.ymm), sum);
    sum = __lasx_xvfmadd_s(__lasx_xvfcvth_s_h(a.ymm), __lasx_xvfcvth_s_h(b.ymm), sum);
    state->sum_f32x8 = (__m256i)sum;
}
|
|
606
|
+
|
|
607
|
+
NK_INTERNAL void nk_dot_f16x16_finalize_loongsonasx( //
|
|
608
|
+
nk_dot_f16x16_state_loongsonasx_t const *state_a, nk_dot_f16x16_state_loongsonasx_t const *state_b, //
|
|
609
|
+
nk_dot_f16x16_state_loongsonasx_t const *state_c, nk_dot_f16x16_state_loongsonasx_t const *state_d, //
|
|
610
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
611
|
+
nk_dot_through_f32_finalize_loongsonasx_(state_a, state_b, state_c, state_d, total_dimensions, result);
|
|
612
|
+
}
|
|
613
|
+
|
|
614
|
+
#pragma endregion F16 and BF16 Floats
|
|
615
|
+
|
|
616
|
+
#pragma region Binary
|
|
617
|
+
|
|
618
|
+
typedef struct nk_dot_u1x256_state_loongsonasx_t {
|
|
619
|
+
__m256i dot_count_u32x8;
|
|
620
|
+
} nk_dot_u1x256_state_loongsonasx_t;
|
|
621
|
+
|
|
622
|
+
NK_INTERNAL void nk_dot_u1x256_init_loongsonasx(nk_dot_u1x256_state_loongsonasx_t *state) {
|
|
623
|
+
state->dot_count_u32x8 = __lasx_xvreplgr2vr_w(0);
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
NK_INTERNAL void nk_dot_u1x256_update_loongsonasx(nk_dot_u1x256_state_loongsonasx_t *state, nk_b256_vec_t a,
                                                  nk_b256_vec_t b, nk_size_t depth_offset,
                                                  nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // A binary dot product is the population count of the bitwise intersection.
    __m256i intersection = __lasx_xvand_v(a.ymm, b.ymm);
    state->dot_count_u32x8 = __lasx_xvadd_w(state->dot_count_u32x8, __lasx_xvpcnt_w(intersection));
}
|
|
634
|
+
|
|
635
|
+
NK_INTERNAL void nk_dot_u1x256_finalize_loongsonasx( //
|
|
636
|
+
nk_dot_u1x256_state_loongsonasx_t const *state_a, nk_dot_u1x256_state_loongsonasx_t const *state_b, //
|
|
637
|
+
nk_dot_u1x256_state_loongsonasx_t const *state_c, nk_dot_u1x256_state_loongsonasx_t const *state_d, //
|
|
638
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
639
|
+
nk_unused_(total_dimensions);
|
|
640
|
+
// Step 1: Fold 8→4 in 256-bit (add high 128-bit lane to low), extract low 128 bits
|
|
641
|
+
__m128i sum_a_u32x4 = nk_lasx_castsi256_si128_(__lasx_xvadd_w(
|
|
642
|
+
state_a->dot_count_u32x8, __lasx_xvpermi_q(state_a->dot_count_u32x8, state_a->dot_count_u32x8, 0x11)));
|
|
643
|
+
__m128i sum_b_u32x4 = nk_lasx_castsi256_si128_(__lasx_xvadd_w(
|
|
644
|
+
state_b->dot_count_u32x8, __lasx_xvpermi_q(state_b->dot_count_u32x8, state_b->dot_count_u32x8, 0x11)));
|
|
645
|
+
__m128i sum_c_u32x4 = nk_lasx_castsi256_si128_(__lasx_xvadd_w(
|
|
646
|
+
state_c->dot_count_u32x8, __lasx_xvpermi_q(state_c->dot_count_u32x8, state_c->dot_count_u32x8, 0x11)));
|
|
647
|
+
__m128i sum_d_u32x4 = nk_lasx_castsi256_si128_(__lasx_xvadd_w(
|
|
648
|
+
state_d->dot_count_u32x8, __lasx_xvpermi_q(state_d->dot_count_u32x8, state_d->dot_count_u32x8, 0x11)));
|
|
649
|
+
// Step 2: Transpose 4×4 in 128-bit via LSX interleave
|
|
650
|
+
__m128i transpose_ab_low_u32x4 = __lsx_vilvl_w(sum_b_u32x4, sum_a_u32x4);
|
|
651
|
+
__m128i transpose_cd_low_u32x4 = __lsx_vilvl_w(sum_d_u32x4, sum_c_u32x4);
|
|
652
|
+
__m128i transpose_ab_high_u32x4 = __lsx_vilvh_w(sum_b_u32x4, sum_a_u32x4);
|
|
653
|
+
__m128i transpose_cd_high_u32x4 = __lsx_vilvh_w(sum_d_u32x4, sum_c_u32x4);
|
|
654
|
+
__m128i sum_lane0_u32x4 = __lsx_vilvl_d(transpose_cd_low_u32x4, transpose_ab_low_u32x4);
|
|
655
|
+
__m128i sum_lane1_u32x4 = __lsx_vilvh_d(transpose_cd_low_u32x4, transpose_ab_low_u32x4);
|
|
656
|
+
__m128i sum_lane2_u32x4 = __lsx_vilvl_d(transpose_cd_high_u32x4, transpose_ab_high_u32x4);
|
|
657
|
+
__m128i sum_lane3_u32x4 = __lsx_vilvh_d(transpose_cd_high_u32x4, transpose_ab_high_u32x4);
|
|
658
|
+
// Step 3: Vertical sum in 128-bit
|
|
659
|
+
result->xmm = __lsx_vadd_w(__lsx_vadd_w(sum_lane0_u32x4, sum_lane1_u32x4),
|
|
660
|
+
__lsx_vadd_w(sum_lane2_u32x4, sum_lane3_u32x4));
|
|
661
|
+
}
|
|
662
|
+
|
|
663
|
+
#pragma endregion Binary
|
|
664
|
+
|
|
665
|
+
#if defined(__cplusplus)
|
|
666
|
+
} // extern "C"
|
|
667
|
+
#endif
|
|
668
|
+
|
|
669
|
+
#endif // NK_TARGET_LOONGSONASX
|
|
670
|
+
#endif // NK_TARGET_LOONGARCH_
|
|
671
|
+
#endif // NK_DOT_LOONGSONASX_H
|