numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -6,11 +6,11 @@
|
|
|
6
6
|
*
|
|
7
7
|
* @section ice_cast_instructions AVX-512 VBMI2 Instructions
|
|
8
8
|
*
|
|
9
|
-
* Intrinsic
|
|
10
|
-
* _mm512_permutex2var_epi16
|
|
11
|
-
* _mm512_test_epi16_mask
|
|
12
|
-
* _mm512_mask_mov_epi16
|
|
13
|
-
* _mm512_cvtepi16_epi8
|
|
9
|
+
* Intrinsic Instruction Icelake Genoa
|
|
10
|
+
* _mm512_permutex2var_epi16 VPERMI2W (ZMM, ZMM, ZMM) 3cy @ p5 2cy @ p12
|
|
11
|
+
* _mm512_test_epi16_mask VPTESTMW (k, ZMM, ZMM) 3cy @ p5 2cy @ p01
|
|
12
|
+
* _mm512_mask_mov_epi16 VMOVDQU16 (ZMM{k}, ZMM) 1cy @ p05 1cy @ p05
|
|
13
|
+
* _mm512_cvtepi16_epi8 VPMOVWB (YMM, ZMM) 3cy @ p5 2cy @ p12
|
|
14
14
|
*
|
|
15
15
|
* Ice Lake's AVX-512 VBMI2 enables efficient 128-entry LUT lookups via dual VPERMI2W operations.
|
|
16
16
|
* FP8-to-BF16/F16 conversions use 4 ZMM LUT registers with VPTESTMW for range selection, achieving
|
|
@@ -37,7 +37,7 @@ extern "C" {
|
|
|
37
37
|
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
|
|
38
38
|
#endif
|
|
39
39
|
|
|
40
|
-
#pragma region
|
|
40
|
+
#pragma region Vectorized Conversions
|
|
41
41
|
|
|
42
42
|
/** @brief Convert 32x e4m3 → 32x bf16 via arithmetic + 8-entry subnormal LUT (AVX-512BW).
|
|
43
43
|
* E4M3 format: S EEEE MMM (bias=7). BF16: S EEEEEEEE MMMMMMM (bias=127).
|
|
@@ -72,7 +72,12 @@ NK_INTERNAL __m512i nk_e4m3x32_to_bf16x32_icelake_(__m256i e4m3x32) {
|
|
|
72
72
|
|
|
73
73
|
// Apply sign: shift E4M3 bit 7 to BF16 bit 15
|
|
74
74
|
sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 8);
|
|
75
|
-
|
|
75
|
+
__m512i result_i16x32 = _mm512_or_si512(result_abs_i16x32, sign_i16x32);
|
|
76
|
+
|
|
77
|
+
// NaN: E4M3FN has NaN only at magnitude 0x7F → BF16 quiet NaN (0x7FC0)
|
|
78
|
+
__mmask32 is_nan = _mm512_cmpeq_epi16_mask(lower7_i16x32, _mm512_set1_epi16(0x7F));
|
|
79
|
+
__m512i nan_i16x32 = _mm512_or_si512(sign_i16x32, _mm512_set1_epi16(0x7FC0));
|
|
80
|
+
return _mm512_mask_blend_epi16(is_nan, result_i16x32, nan_i16x32);
|
|
76
81
|
}
|
|
77
82
|
|
|
78
83
|
/** @brief Convert 32x e5m2 → 32x bf16 via arithmetic + 4-entry subnormal LUT (AVX-512BW).
|
|
@@ -268,14 +273,14 @@ NK_INTERNAL __m256i nk_bf16x32_to_e4m3x32_icelake_(__m512i bf16x32) {
|
|
|
268
273
|
// bf16 to f32 is just left shift by 16
|
|
269
274
|
__m512i bf16_low_i32x16 = _mm512_cvtepu16_epi32(_mm512_castsi512_si256(bf16x32));
|
|
270
275
|
__m512i bf16_high_i32x16 = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(bf16x32, 1));
|
|
271
|
-
__m512
|
|
272
|
-
__m512
|
|
273
|
-
__m512
|
|
274
|
-
__m512
|
|
275
|
-
__m512
|
|
276
|
-
__m512
|
|
277
|
-
__m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(
|
|
278
|
-
__m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(
|
|
276
|
+
__m512 f32_low_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_low_i32x16, 16));
|
|
277
|
+
__m512 f32_high_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_high_i32x16, 16));
|
|
278
|
+
__m512 abs_f32_low_f32x16 = _mm512_and_ps(f32_low_f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
279
|
+
__m512 abs_f32_high_f32x16 = _mm512_and_ps(f32_high_f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
280
|
+
__m512 scaled_low_f32x16 = _mm512_mul_ps(abs_f32_low_f32x16, _mm512_set1_ps(512.0f));
|
|
281
|
+
__m512 scaled_high_f32x16 = _mm512_mul_ps(abs_f32_high_f32x16, _mm512_set1_ps(512.0f));
|
|
282
|
+
__m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(scaled_low_f32x16);
|
|
283
|
+
__m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(scaled_high_f32x16);
|
|
279
284
|
__m256i subnorm_mant_low_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_low_i32x16);
|
|
280
285
|
__m256i subnorm_mant_high_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_high_i32x16);
|
|
281
286
|
__m512i subnorm_mantissa_i16x32 = _mm512_inserti64x4(_mm512_castsi256_si512(subnorm_mant_low_i16x16),
|
|
@@ -328,14 +333,14 @@ NK_INTERNAL __m256i nk_bf16x32_to_e5m2x32_icelake_(__m512i bf16x32) {
|
|
|
328
333
|
// Subnormal path: compute via f32 to get correct rounding
|
|
329
334
|
__m512i bf16_low_i32x16 = _mm512_cvtepu16_epi32(_mm512_castsi512_si256(bf16x32));
|
|
330
335
|
__m512i bf16_high_i32x16 = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(bf16x32, 1));
|
|
331
|
-
__m512
|
|
332
|
-
__m512
|
|
333
|
-
__m512
|
|
334
|
-
__m512
|
|
335
|
-
__m512
|
|
336
|
-
__m512
|
|
337
|
-
__m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(
|
|
338
|
-
__m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(
|
|
336
|
+
__m512 f32_low_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_low_i32x16, 16));
|
|
337
|
+
__m512 f32_high_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_high_i32x16, 16));
|
|
338
|
+
__m512 abs_f32_low_f32x16 = _mm512_and_ps(f32_low_f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
339
|
+
__m512 abs_f32_high_f32x16 = _mm512_and_ps(f32_high_f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
|
|
340
|
+
__m512 scaled_low_f32x16 = _mm512_mul_ps(abs_f32_low_f32x16, _mm512_set1_ps(65536.0f));
|
|
341
|
+
__m512 scaled_high_f32x16 = _mm512_mul_ps(abs_f32_high_f32x16, _mm512_set1_ps(65536.0f));
|
|
342
|
+
__m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(scaled_low_f32x16);
|
|
343
|
+
__m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(scaled_high_f32x16);
|
|
339
344
|
__m256i subnorm_mant_low_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_low_i32x16);
|
|
340
345
|
__m256i subnorm_mant_high_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_high_i32x16);
|
|
341
346
|
__m512i subnorm_mantissa_i16x32 = _mm512_inserti64x4(_mm512_castsi256_si512(subnorm_mant_low_i16x16),
|
|
@@ -362,8 +367,8 @@ NK_INTERNAL void nk_load_e4m3x32_to_bf16x32_icelake_(void const *src, nk_b512_ve
|
|
|
362
367
|
/** @brief Partial load n e4m3 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
|
|
363
368
|
NK_INTERNAL void nk_partial_load_e4m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
364
369
|
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
365
|
-
__m256i
|
|
366
|
-
dst->zmm = nk_e4m3x32_to_bf16x32_icelake_(
|
|
370
|
+
__m256i e4m3_partial_i8x32 = _mm256_maskz_loadu_epi8(mask, src);
|
|
371
|
+
dst->zmm = nk_e4m3x32_to_bf16x32_icelake_(e4m3_partial_i8x32);
|
|
367
372
|
}
|
|
368
373
|
|
|
369
374
|
/** @brief Load 32x e5m2 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
|
|
@@ -374,8 +379,8 @@ NK_INTERNAL void nk_load_e5m2x32_to_bf16x32_icelake_(void const *src, nk_b512_ve
|
|
|
374
379
|
/** @brief Partial load n e5m2 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
|
|
375
380
|
NK_INTERNAL void nk_partial_load_e5m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
376
381
|
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
377
|
-
__m256i
|
|
378
|
-
dst->zmm = nk_e5m2x32_to_bf16x32_icelake_(
|
|
382
|
+
__m256i e5m2_partial_i8x32 = _mm256_maskz_loadu_epi8(mask, src);
|
|
383
|
+
dst->zmm = nk_e5m2x32_to_bf16x32_icelake_(e5m2_partial_i8x32);
|
|
379
384
|
}
|
|
380
385
|
|
|
381
386
|
/** @brief Load 32x e2m3 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
|
|
@@ -386,8 +391,8 @@ NK_INTERNAL void nk_load_e2m3x32_to_bf16x32_icelake_(void const *src, nk_b512_ve
|
|
|
386
391
|
/** @brief Partial load n e2m3 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
|
|
387
392
|
NK_INTERNAL void nk_partial_load_e2m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
388
393
|
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
389
|
-
__m256i
|
|
390
|
-
dst->zmm = nk_e2m3x32_to_bf16x32_icelake_(
|
|
394
|
+
__m256i e2m3_partial_i8x32 = _mm256_maskz_loadu_epi8(mask, src);
|
|
395
|
+
dst->zmm = nk_e2m3x32_to_bf16x32_icelake_(e2m3_partial_i8x32);
|
|
391
396
|
}
|
|
392
397
|
|
|
393
398
|
/** @brief Load 32x e3m2 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
|
|
@@ -398,13 +403,13 @@ NK_INTERNAL void nk_load_e3m2x32_to_bf16x32_icelake_(void const *src, nk_b512_ve
|
|
|
398
403
|
/** @brief Partial load n e3m2 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
|
|
399
404
|
NK_INTERNAL void nk_partial_load_e3m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
|
|
400
405
|
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
|
|
401
|
-
__m256i
|
|
402
|
-
dst->zmm = nk_e3m2x32_to_bf16x32_icelake_(
|
|
406
|
+
__m256i e3m2_partial_i8x32 = _mm256_maskz_loadu_epi8(mask, src);
|
|
407
|
+
dst->zmm = nk_e3m2x32_to_bf16x32_icelake_(e3m2_partial_i8x32);
|
|
403
408
|
}
|
|
404
409
|
|
|
405
|
-
#pragma endregion
|
|
410
|
+
#pragma endregion Vectorized Conversions
|
|
406
411
|
|
|
407
|
-
#pragma region
|
|
412
|
+
#pragma region Public API
|
|
408
413
|
|
|
409
414
|
NK_PUBLIC void nk_cast_icelake(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
|
|
410
415
|
// Group 1: Conversions to bf16 (e4m3 → bf16, e5m2 → bf16)
|
|
@@ -428,9 +433,9 @@ NK_PUBLIC void nk_cast_icelake(void const *from, nk_dtype_t from_type, nk_size_t
|
|
|
428
433
|
for (nk_size_t idx = 0; idx < n; idx += 32) {
|
|
429
434
|
nk_size_t remaining = n - idx;
|
|
430
435
|
__mmask32 mask = (remaining >= 32) ? 0xFFFFFFFF : _bzhi_u32(0xFFFFFFFF, (unsigned)remaining);
|
|
431
|
-
__m512i
|
|
432
|
-
__m256i out_f8x32 = (to_type == nk_e4m3_k) ? nk_bf16x32_to_e4m3x32_icelake_(
|
|
433
|
-
: nk_bf16x32_to_e5m2x32_icelake_(
|
|
436
|
+
__m512i in_bf16x32_i16x32 = _mm512_maskz_loadu_epi16(mask, from_ptr + idx);
|
|
437
|
+
__m256i out_f8x32 = (to_type == nk_e4m3_k) ? nk_bf16x32_to_e4m3x32_icelake_(in_bf16x32_i16x32)
|
|
438
|
+
: nk_bf16x32_to_e5m2x32_icelake_(in_bf16x32_i16x32);
|
|
434
439
|
_mm256_mask_storeu_epi8(to_ptr + idx, mask, out_f8x32);
|
|
435
440
|
}
|
|
436
441
|
}
|
|
@@ -453,7 +458,7 @@ NK_PUBLIC void nk_cast_icelake(void const *from, nk_dtype_t from_type, nk_size_t
|
|
|
453
458
|
else nk_cast_skylake(from, from_type, n, to, to_type);
|
|
454
459
|
}
|
|
455
460
|
|
|
456
|
-
#pragma endregion
|
|
461
|
+
#pragma endregion Public API
|
|
457
462
|
|
|
458
463
|
#if defined(__clang__)
|
|
459
464
|
#pragma clang attribute pop
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Type Conversions and Load/Store Helpers for LoongArch LASX (256-bit).
|
|
3
|
+
* @file include/numkong/cast/loongsonasx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/cast.h
|
|
8
|
+
*
|
|
9
|
+
* @section loongsonasx_cast_instructions Key LASX Load/Store Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Description
|
|
12
|
+
* __lasx_xvld(ptr, 0) XVLD 256-bit aligned/unaligned load
|
|
13
|
+
* __lasx_xvst(v, ptr, 0) XVST 256-bit aligned/unaligned store
|
|
14
|
+
* __lasx_xvreplgr2vr_w(bits) XVREPLGR2VR.W Broadcast i32 to 8 lanes
|
|
15
|
+
* __lasx_xvreplgr2vr_d(bits) XVREPLGR2VR.D Broadcast i64 to 4 lanes
|
|
16
|
+
* __lasx_xvffint_s_w(v) XVFFINT.S.W 4x i32 -> f32 (per 128-bit lane)
|
|
17
|
+
* __lasx_xvfrsqrt_s(v) XVFRSQRT.S f32 full-precision reciprocal sqrt
|
|
18
|
+
* __lasx_xvfsqrt_s(v) XVFSQRT.S f32 full-precision sqrt
|
|
19
|
+
* __lasx_xvfsqrt_d(v) XVFSQRT.D f64 full-precision sqrt
|
|
20
|
+
*
|
|
21
|
+
* LASX is a 256-bit extension; all vector registers are 256-bit `__m256i`. For 128-bit
|
|
22
|
+
* `nk_b128_vec_t` operations, `__lasx_xvld` safely loads into the low 128 bits (the high
|
|
23
|
+
* 128 bits are zeroed or undefined depending on context). For 128-bit stores we use `memcpy`
|
|
24
|
+
* to avoid writing beyond the intended 16 bytes. Partial loads/stores delegate to serial
|
|
25
|
+
* helpers since LASX lacks masked load/store instructions.
|
|
26
|
+
*/
|
|
27
|
+
#ifndef NK_CAST_LOONGSONASX_H
|
|
28
|
+
#define NK_CAST_LOONGSONASX_H
|
|
29
|
+
|
|
30
|
+
#if NK_TARGET_LOONGARCH_
|
|
31
|
+
#if NK_TARGET_LOONGSONASX
|
|
32
|
+
|
|
33
|
+
#include "numkong/types.h"
|
|
34
|
+
#include "numkong/cast/serial.h" // `nk_partial_load_b32x4_serial_`, `nk_partial_load_b64x4_serial_`
|
|
35
|
+
#include "numkong/scalar/loongsonasx.h" // `nk_xvreplgr2vr_s_128_`, `nk_xvfreplgr2vr_s_`
|
|
36
|
+
|
|
37
|
+
#if defined(__cplusplus)
|
|
38
|
+
extern "C" {
|
|
39
|
+
#endif
|
|
40
|
+
|
|
41
|
+
#pragma region Type Punned Loads and Stores
|
|
42
|
+
|
|
43
|
+
/**
|
|
44
|
+
* LSX and LASX share the same physical register file, so widening __m128i → __m256i and
|
|
45
|
+
* extracting __m256i → __m128i are no-ops on hardware. Empty inline asm with "f" constraints
|
|
46
|
+
* avoids the stack round-trip that union punning causes on GCC 14.
|
|
47
|
+
* Named after x86 `_mm256_castsi128_si256` / `_mm256_castsi256_si128` / `_mm256_castps256_ps128`.
|
|
48
|
+
*/
|
|
49
|
+
NK_INTERNAL __m256i nk_lasx_castsi128_si256_(__m128i low_i64x2) {
|
|
50
|
+
__m256i wide_i64x4;
|
|
51
|
+
__asm__("" : "=f"(wide_i64x4) : "f"(low_i64x2));
|
|
52
|
+
return wide_i64x4;
|
|
53
|
+
}
|
|
54
|
+
NK_INTERNAL __m128i nk_lasx_castsi256_si128_(__m256i wide_i64x4) {
|
|
55
|
+
__m128i low_i64x2;
|
|
56
|
+
__asm__("" : "=f"(low_i64x2) : "f"(wide_i64x4));
|
|
57
|
+
return low_i64x2;
|
|
58
|
+
}
|
|
59
|
+
NK_INTERNAL __m128 nk_lasx_castps256_ps128_(__m256 wide_f32x8) {
|
|
60
|
+
__m128 low_f32x4;
|
|
61
|
+
__asm__("" : "=f"(low_f32x4) : "f"(wide_f32x8));
|
|
62
|
+
return low_f32x4;
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
/** @brief Type-agnostic 256-bit full load (LASX). */
|
|
66
|
+
NK_INTERNAL void nk_load_b256_loongsonasx_(void const *src, nk_b256_vec_t *dst) { dst->ymm = __lasx_xvld(src, 0); }
|
|
67
|
+
|
|
68
|
+
/** @brief Type-agnostic 256-bit full store (LASX). */
|
|
69
|
+
NK_INTERNAL void nk_store_b256_loongsonasx_(nk_b256_vec_t const *src, void *dst) { __lasx_xvst(src->ymm, dst, 0); }
|
|
70
|
+
|
|
71
|
+
/** @brief Type-agnostic 128-bit full load (LSX subset of LASX). */
|
|
72
|
+
NK_INTERNAL void nk_load_b128_loongsonasx_(void const *src, nk_b128_vec_t *dst) { dst->xmm = __lsx_vld(src, 0); }
|
|
73
|
+
|
|
74
|
+
/** @brief Type-agnostic 128-bit full store (LSX subset of LASX). */
|
|
75
|
+
NK_INTERNAL void nk_store_b128_loongsonasx_(nk_b128_vec_t const *src, void *dst) { __lsx_vst(src->xmm, dst, 0); }
|
|
76
|
+
|
|
77
|
+
/** @brief Convert 8 × bf16 → 8 × f32 by interleaving with zero so bf16 lands in upper 16 bits (LASX).
|
|
78
|
+
*
|
|
79
|
+
* Duplicates the 128-bit input into both lanes, then `xvilvl_h(bf16, zero)` places each bf16
|
|
80
|
+
* value in the high 16 bits of a 32-bit slot — which is valid f32 with no shift needed.
|
|
81
|
+
* `xvpermi_q` combines the low-element and high-element halves into a single register.
|
|
82
|
+
*/
|
|
83
|
+
NK_INTERNAL __m256i nk_bf16x8_to_f32x8_loongsonasx_(__m128i bf16_i16x8) {
|
|
84
|
+
__m256i duped_bf16x16 = __lasx_xvpermi_q(nk_lasx_castsi128_si256_(bf16_i16x8), nk_lasx_castsi128_si256_(bf16_i16x8),
|
|
85
|
+
0x00);
|
|
86
|
+
__m256i zero_i16x16 = __lasx_xvreplgr2vr_h(0);
|
|
87
|
+
__m256i low_f32x8 = __lasx_xvilvl_h(duped_bf16x16, zero_i16x16);
|
|
88
|
+
__m256i high_f32x8 = __lasx_xvilvh_h(duped_bf16x16, zero_i16x16);
|
|
89
|
+
return __lasx_xvpermi_q(high_f32x8, low_f32x8, 0x20);
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
/** @brief Load 8 × bf16 from memory, convert to 8 × f32, store in 256-bit vector (LASX). */
|
|
93
|
+
NK_INTERNAL void nk_load_bf16x8_to_f32x8_loongsonasx_(void const *src, nk_b256_vec_t *dst) {
|
|
94
|
+
dst->ymm = nk_bf16x8_to_f32x8_loongsonasx_(__lsx_vld(src, 0));
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
/** @brief Partial load for bf16 elements (up to 8) with conversion to f32 (LASX). */
|
|
98
|
+
NK_INTERNAL void nk_partial_load_bf16x8_to_f32x8_loongsonasx_(nk_bf16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
99
|
+
nk_b128_vec_t vec;
|
|
100
|
+
nk_partial_load_b16x8_serial_(src, &vec, n);
|
|
101
|
+
dst->ymm = nk_bf16x8_to_f32x8_loongsonasx_(vec.xmm);
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
/** @brief Convert 8 × f16 → 8 × f32 via native LASX hardware conversion. */
|
|
105
|
+
NK_INTERNAL __m256i nk_f16x8_to_f32x8_loongsonasx_(__m128i f16_i16x8) {
|
|
106
|
+
__m256i duped_f16x16 = __lasx_xvpermi_q(nk_lasx_castsi128_si256_(f16_i16x8), nk_lasx_castsi128_si256_(f16_i16x8),
|
|
107
|
+
0x00);
|
|
108
|
+
__m256i low_f32x8 = (__m256i)__lasx_xvfcvtl_s_h(duped_f16x16);
|
|
109
|
+
__m256i high_f32x8 = (__m256i)__lasx_xvfcvth_s_h(duped_f16x16);
|
|
110
|
+
return __lasx_xvpermi_q(high_f32x8, low_f32x8, 0x20);
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
/** @brief Load 8 × f16 from memory, convert to 8 × f32 via native LASX conversion. */
|
|
114
|
+
NK_INTERNAL void nk_load_f16x8_to_f32x8_loongsonasx_(void const *src, nk_b256_vec_t *dst) {
|
|
115
|
+
dst->ymm = nk_f16x8_to_f32x8_loongsonasx_(__lsx_vld(src, 0));
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
/** @brief Partial load for f16 elements (up to 8) with conversion to f32 (LASX). */
|
|
119
|
+
NK_INTERNAL void nk_partial_load_f16x8_to_f32x8_loongsonasx_(nk_f16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
120
|
+
nk_b128_vec_t vec;
|
|
121
|
+
nk_partial_load_b16x8_serial_(src, &vec, n);
|
|
122
|
+
dst->ymm = nk_f16x8_to_f32x8_loongsonasx_(vec.xmm);
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
#pragma endregion Type Punned Loads and Stores
|
|
126
|
+
|
|
127
|
+
#pragma region Vectorized From Dot Helpers
|
|
128
|
+
|
|
129
|
+
/** @brief Safe square root of 8 floats with zero-clamping for numerical stability (LASX 256-bit). */
|
|
130
|
+
NK_INTERNAL __m256 nk_sqrt_f32x8_loongsonasx_(__m256 x_f32x8) {
|
|
131
|
+
__m256 zero_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
|
|
132
|
+
return __lasx_xvfsqrt_s(__lasx_xvfmax_s(x_f32x8, zero_f32x8));
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
/** @brief Safe square root of 4 floats with zero-clamping for numerical stability (LSX 128-bit). */
|
|
136
|
+
NK_INTERNAL __m128 nk_sqrt_f32x4_loongsonasx_(__m128 x_f32x4) {
|
|
137
|
+
__m128 zero_f32x4 = (__m128)__lsx_vreplgr2vr_w(0);
|
|
138
|
+
return __lsx_vfsqrt_s(__lsx_vfmax_s(x_f32x4, zero_f32x4));
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
/** @brief Angular from_dot: computes 1 − dot × rsqrt(query_sumsq × target_sumsq) for 4 pairs (LSX 128-bit f32). */
|
|
142
|
+
NK_INTERNAL void nk_angular_through_f32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
143
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
144
|
+
__m128 dots_f32x4 = dots.xmm_ps;
|
|
145
|
+
__m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_(query_sumsq);
|
|
146
|
+
__m128 products_f32x4 = __lsx_vfmul_s(query_sumsq_f32x4, target_sumsqs.xmm_ps);
|
|
147
|
+
__m128 rsqrt_f32x4 = __lsx_vfrsqrt_s(products_f32x4);
|
|
148
|
+
__m128 normalized_f32x4 = __lsx_vfmul_s(dots_f32x4, rsqrt_f32x4);
|
|
149
|
+
__m128 one_f32x4 = nk_xvreplgr2vr_s_128_(1.0f);
|
|
150
|
+
__m128 angular_f32x4 = __lsx_vfsub_s(one_f32x4, normalized_f32x4);
|
|
151
|
+
__m128 zero_f32x4 = (__m128)__lsx_vreplgr2vr_w(0);
|
|
152
|
+
results->xmm_ps = __lsx_vfmax_s(angular_f32x4, zero_f32x4);
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
/** @brief Euclidean from_dot: computes √(query_sumsq + target_sumsq − 2 × dot) for 4 pairs (LSX 128-bit f32). */
|
|
156
|
+
NK_INTERNAL void nk_euclidean_through_f32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
|
|
157
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
158
|
+
__m128 dots_f32x4 = dots.xmm_ps;
|
|
159
|
+
__m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_(query_sumsq);
|
|
160
|
+
__m128 sum_sq_f32x4 = __lsx_vfadd_s(query_sumsq_f32x4, target_sumsqs.xmm_ps);
|
|
161
|
+
__m128 two_f32x4 = nk_xvreplgr2vr_s_128_(2.0f);
|
|
162
|
+
// dist_sq = sum_sq − 2 × dots = -(2 × dots − sum_sq)
|
|
163
|
+
__m128 dist_sq_f32x4 = __lsx_vfnmsub_s(two_f32x4, dots_f32x4, sum_sq_f32x4);
|
|
164
|
+
results->xmm_ps = nk_sqrt_f32x4_loongsonasx_(dist_sq_f32x4);
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
/** @brief Angular from_dot for native f64: 1 − dot / √(query_sumsq × target_sumsq) for 4 pairs (LASX 256-bit). */
|
|
168
|
+
NK_INTERNAL void nk_angular_through_f64_from_dot_loongsonasx_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
|
|
169
|
+
nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
|
|
170
|
+
__m256d dots_f64x4 = dots.ymm_pd;
|
|
171
|
+
__m256d query_sumsq_f64x4 = nk_xvfreplgr2vr_d_(query_sumsq);
|
|
172
|
+
__m256d products_f64x4 = __lasx_xvfmul_d(query_sumsq_f64x4, target_sumsqs.ymm_pd);
|
|
173
|
+
__m256d sqrt_products_f64x4 = __lasx_xvfsqrt_d(products_f64x4);
|
|
174
|
+
__m256d normalized_f64x4 = __lasx_xvfdiv_d(dots_f64x4, sqrt_products_f64x4);
|
|
175
|
+
__m256d one_f64x4 = nk_xvfreplgr2vr_d_(1.0);
|
|
176
|
+
__m256d angular_f64x4 = __lasx_xvfsub_d(one_f64x4, normalized_f64x4);
|
|
177
|
+
__m256d zero_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
|
|
178
|
+
results->ymm_pd = __lasx_xvfmax_d(angular_f64x4, zero_f64x4);
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
/** @brief Euclidean from_dot for native f64: √(query_sumsq + target_sumsq − 2 × dot) for 4 pairs (LASX 256-bit). */
|
|
182
|
+
NK_INTERNAL void nk_euclidean_through_f64_from_dot_loongsonasx_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
|
|
183
|
+
nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
|
|
184
|
+
__m256d dots_f64x4 = dots.ymm_pd;
|
|
185
|
+
__m256d query_sumsq_f64x4 = nk_xvfreplgr2vr_d_(query_sumsq);
|
|
186
|
+
__m256d sum_sq_f64x4 = __lasx_xvfadd_d(query_sumsq_f64x4, target_sumsqs.ymm_pd);
|
|
187
|
+
__m256d two_f64x4 = nk_xvfreplgr2vr_d_(2.0);
|
|
188
|
+
// dist_sq = sum_sq − 2 × dots = -(2 × dots − sum_sq)
|
|
189
|
+
__m256d dist_sq_f64x4 = __lasx_xvfnmsub_d(two_f64x4, dots_f64x4, sum_sq_f64x4);
|
|
190
|
+
__m256d zero_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
|
|
191
|
+
results->ymm_pd = __lasx_xvfsqrt_d(__lasx_xvfmax_d(dist_sq_f64x4, zero_f64x4));
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
/** @brief Angular from_dot for i32 accumulators: cast i32 → f32, rsqrt+NR, clamp. 4 pairs (LSX 128-bit). */
|
|
195
|
+
NK_INTERNAL void nk_angular_through_i32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
|
|
196
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
197
|
+
__m128 dots_f32x4 = __lsx_vffint_s_w(dots.xmm);
|
|
198
|
+
__m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_((nk_f32_t)query_sumsq);
|
|
199
|
+
__m128 products_f32x4 = __lsx_vfmul_s(query_sumsq_f32x4, __lsx_vffint_s_w(target_sumsqs.xmm));
|
|
200
|
+
__m128 rsqrt_f32x4 = __lsx_vfrsqrt_s(products_f32x4);
|
|
201
|
+
__m128 normalized_f32x4 = __lsx_vfmul_s(dots_f32x4, rsqrt_f32x4);
|
|
202
|
+
__m128 one_f32x4 = nk_xvreplgr2vr_s_128_(1.0f);
|
|
203
|
+
__m128 angular_f32x4 = __lsx_vfsub_s(one_f32x4, normalized_f32x4);
|
|
204
|
+
__m128 zero_f32x4 = (__m128)__lsx_vreplgr2vr_w(0);
|
|
205
|
+
results->xmm_ps = __lsx_vfmax_s(angular_f32x4, zero_f32x4);
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
/** @brief Euclidean from_dot for i32 accumulators: cast i32 → f32, then √(a² + b² − 2ab). 4 pairs (LSX 128-bit). */
|
|
209
|
+
NK_INTERNAL void nk_euclidean_through_i32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
|
|
210
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
211
|
+
__m128 dots_f32x4 = __lsx_vffint_s_w(dots.xmm);
|
|
212
|
+
__m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_((nk_f32_t)query_sumsq);
|
|
213
|
+
__m128 sum_sq_f32x4 = __lsx_vfadd_s(query_sumsq_f32x4, __lsx_vffint_s_w(target_sumsqs.xmm));
|
|
214
|
+
__m128 two_f32x4 = nk_xvreplgr2vr_s_128_(2.0f);
|
|
215
|
+
__m128 dist_sq_f32x4 = __lsx_vfnmsub_s(two_f32x4, dots_f32x4, sum_sq_f32x4);
|
|
216
|
+
results->xmm_ps = nk_sqrt_f32x4_loongsonasx_(dist_sq_f32x4);
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
/** @brief Angular from_dot for u32 accumulators: cast u32 → f32, rsqrt+NR, clamp. 4 pairs (LSX 128-bit). */
|
|
220
|
+
NK_INTERNAL void nk_angular_through_u32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
|
|
221
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
222
|
+
__m128 dots_f32x4 = __lsx_vffint_s_w(dots.xmm);
|
|
223
|
+
__m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_((nk_f32_t)query_sumsq);
|
|
224
|
+
__m128 products_f32x4 = __lsx_vfmul_s(query_sumsq_f32x4, __lsx_vffint_s_w(target_sumsqs.xmm));
|
|
225
|
+
__m128 rsqrt_f32x4 = __lsx_vfrsqrt_s(products_f32x4);
|
|
226
|
+
__m128 normalized_f32x4 = __lsx_vfmul_s(dots_f32x4, rsqrt_f32x4);
|
|
227
|
+
__m128 one_f32x4 = nk_xvreplgr2vr_s_128_(1.0f);
|
|
228
|
+
__m128 angular_f32x4 = __lsx_vfsub_s(one_f32x4, normalized_f32x4);
|
|
229
|
+
__m128 zero_f32x4 = (__m128)__lsx_vreplgr2vr_w(0);
|
|
230
|
+
results->xmm_ps = __lsx_vfmax_s(angular_f32x4, zero_f32x4);
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
/** @brief Euclidean from_dot for u32 accumulators: cast u32 → f32, then √(a² + b² − 2ab). 4 pairs (LSX 128-bit). */
|
|
234
|
+
NK_INTERNAL void nk_euclidean_through_u32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
|
|
235
|
+
nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
|
|
236
|
+
__m128 dots_f32x4 = __lsx_vffint_s_w(dots.xmm);
|
|
237
|
+
__m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_((nk_f32_t)query_sumsq);
|
|
238
|
+
__m128 sum_sq_f32x4 = __lsx_vfadd_s(query_sumsq_f32x4, __lsx_vffint_s_w(target_sumsqs.xmm));
|
|
239
|
+
__m128 two_f32x4 = nk_xvreplgr2vr_s_128_(2.0f);
|
|
240
|
+
__m128 dist_sq_f32x4 = __lsx_vfnmsub_s(two_f32x4, dots_f32x4, sum_sq_f32x4);
|
|
241
|
+
results->xmm_ps = nk_sqrt_f32x4_loongsonasx_(dist_sq_f32x4);
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
#pragma endregion Vectorized From Dot Helpers
|
|
245
|
+
|
|
246
|
+
#if defined(__cplusplus)
|
|
247
|
+
} // extern "C"
|
|
248
|
+
#endif
|
|
249
|
+
|
|
250
|
+
#endif // NK_TARGET_LOONGSONASX
|
|
251
|
+
#endif // NK_TARGET_LOONGARCH_
|
|
252
|
+
#endif // NK_CAST_LOONGSONASX_H
|