cui-llama.rn 1.4.3 → 1.4.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +93 -114
- package/android/src/main/CMakeLists.txt +5 -0
- package/android/src/main/build-arm64/CMakeCache.txt +429 -0
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +21 -21
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +101 -0
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
- package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +376 -0
- package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +16 -0
- package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +165 -0
- package/android/src/main/build-arm64/CMakeFiles/Makefile2 +297 -0
- package/android/src/main/build-arm64/CMakeFiles/Progress/1 +1 -0
- package/android/src/main/build-arm64/CMakeFiles/Progress/2 +1 -0
- package/android/src/main/build-arm64/CMakeFiles/Progress/3 +1 -0
- package/android/src/main/build-arm64/CMakeFiles/Progress/4 +1 -0
- package/android/src/main/build-arm64/CMakeFiles/Progress/5 +1 -0
- package/android/src/main/build-arm64/CMakeFiles/Progress/6 +1 -0
- package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +1 -0
- package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +8 -0
- package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +1 -0
- package/android/src/main/build-arm64/CMakeFiles/progress.marks +1 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +58 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +756 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +709 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +714 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +62 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +708 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +113 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +713 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +763 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +61 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +707 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +104 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +714 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +723 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +62 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +722 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +89 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +17 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +41 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +62 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +722 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +89 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +17 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +41 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +62 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +722 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +89 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +17 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +41 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +62 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +722 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +89 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +17 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +41 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +62 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +722 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +89 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +17 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +41 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +62 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +722 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +89 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +2 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +17 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +41 -0
- package/android/src/main/build-arm64/Makefile +1862 -0
- package/android/src/main/build-arm64/cmake_install.cmake +66 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +91 -17
- package/android/src/main/java/com/rnllama/RNLlama.java +37 -4
- package/android/src/main/jni-utils.h +6 -0
- package/android/src/main/jni.cpp +287 -31
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +7 -2
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +7 -2
- package/cpp/chat-template.hpp +529 -0
- package/cpp/chat.cpp +1085 -0
- package/cpp/chat.hpp +55 -0
- package/cpp/common.cpp +159 -36
- package/cpp/common.h +64 -19
- package/cpp/ggml-alloc.c +1 -13
- package/cpp/ggml-common.h +0 -2
- package/cpp/ggml-cpu-impl.h +6 -12
- package/cpp/ggml-cpu-quants.c +937 -340
- package/cpp/ggml-cpu.c +207 -113
- package/cpp/ggml-cpu.cpp +4 -6
- package/cpp/ggml-cpu.h +1 -1
- package/cpp/ggml-metal.h +66 -66
- package/cpp/ggml-metal.m +141 -23
- package/cpp/ggml.c +24 -14
- package/cpp/ggml.h +2 -2
- package/cpp/json-schema-to-grammar.cpp +46 -66
- package/cpp/json-schema-to-grammar.h +15 -1
- package/cpp/llama-arch.cpp +7 -2
- package/cpp/llama-arch.h +3 -1
- package/cpp/llama-chat.cpp +10 -1
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-grammar.cpp +86 -6
- package/cpp/llama-grammar.h +22 -1
- package/cpp/llama-impl.h +6 -6
- package/cpp/llama-kv-cache.h +1 -1
- package/cpp/llama-mmap.h +1 -0
- package/cpp/llama-model-loader.cpp +1 -1
- package/cpp/llama-model.cpp +32 -6
- package/cpp/llama-sampling.cpp +178 -61
- package/cpp/llama-vocab.cpp +8 -3
- package/cpp/llama.cpp +188 -128
- package/cpp/llama.h +27 -10
- package/cpp/log.cpp +32 -10
- package/cpp/log.h +12 -1
- package/cpp/minja.hpp +2883 -0
- package/cpp/rn-llama.cpp +82 -5
- package/cpp/rn-llama.h +16 -1
- package/cpp/sampling.cpp +68 -41
- package/cpp/sampling.h +3 -0
- package/cpp/sgemm.cpp +9 -8
- package/cpp/unicode.cpp +9 -2
- package/ios/CMakeLists.txt +6 -0
- package/ios/RNLlama.h +0 -8
- package/ios/RNLlama.mm +27 -3
- package/ios/RNLlamaContext.h +10 -1
- package/ios/RNLlamaContext.mm +269 -57
- package/jest/mock.js +21 -2
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/grammar.js +3 -0
- package/lib/commonjs/grammar.js.map +1 -1
- package/lib/commonjs/index.js +87 -13
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/grammar.js +3 -0
- package/lib/module/grammar.js.map +1 -1
- package/lib/module/index.js +86 -13
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +107 -2
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/grammar.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +32 -7
- package/lib/typescript/index.d.ts.map +1 -1
- package/llama-rn.podspec +1 -1
- package/package.json +3 -2
- package/src/NativeRNLlama.ts +115 -3
- package/src/grammar.ts +3 -0
- package/src/index.ts +138 -21
package/cpp/ggml-cpu-quants.c
CHANGED
@@ -297,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
 static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
 #endif

+#if defined(__loongarch_sx)
+
+static __m128i lsx_packs_w(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_w(a, 15);
+    tmp1 = __lsx_vsat_w(b, 15);
+    return __lsx_vpickev_h(tmp1, tmp);
+}
+
+static __m128i lsx_packs_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_h(a, 7);
+    tmp1 = __lsx_vsat_h(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packus_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_hu(a, 7);
+    tmp1 = __lsx_vsat_hu(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_h_b(a, b);
+    tmp2 = __lsx_vmulwod_h_b(a, b);
+    return __lsx_vsadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_madd_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_w_h(a, b);
+    tmp2 = __lsx_vmulwod_w_h(a, b);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
+    v4i32 __ret = {d, c, b, a};
+    return (__m128i)__ret;
+}
+
+static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
+    __m128i mask_f, zero, tmp0, tmp2, mask;
+    int f = 0x8f;
+    mask_f = __lsx_vreplgr2vr_b(f);
+    zero = __lsx_vldi(0);
+    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
+    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
+    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
+    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
+    return __lsx_vshuf_b(a, zero, tmp2);
+}
+
+static __m128i lsx_hadd_h(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_h(b, a);
+    __m128i tmp2 = __lsx_vpickod_h(b, a);
+    return __lsx_vadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_hadd_w(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_w(b, a);
+    __m128i tmp2 = __lsx_vpickod_w(b, a);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128 lsx_hadd_s(__m128 a, __m128 b) {
+    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
+    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
+
+    return __lsx_vfadd_s(tmp1, tmp2);
+}
+
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 = lsx_hadd_s(a, b);
+    __m128 res_1 = lsx_hadd_s(c, d);
+    __m128 res   = lsx_hadd_s(res_0, res_1);
+    res = lsx_hadd_s(res, res);
+    res = lsx_hadd_s(res, res);
+
+    return ((v4f32)res)[0];
+}
+#endif
+
 #if defined(__loongarch_asx)

 #ifdef __clang__
@@ -395,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
     return (__m256i)__ret;
 }

-static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
-    v4i32 __ret = {d, c, b, a};
-    return (__m128i)__ret;
-}
-
 static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
     v4i64 __ret = {d, c, b, a};
     return (__m256i)__ret;
@@ -409,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
     return lasx_set_q(x, y);
 }

-static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
-    __m128i mask_f, zero, tmp0, tmp2, mask;
-    int f = 0x8f;
-    mask_f = __lsx_vreplgr2vr_b(f);
-    zero = __lsx_vldi(0);
-    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
-    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
-    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
-    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
-    return __lsx_vshuf_b(a, zero, tmp2);
-}
-
 static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
     __m256i mask_f, zero, tmp0, tmp2, mask;
     int f = 0x8f;
@@ -434,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
 }

 static __m256i lasx_extu8_16(__m128i a) {
-
-    __m128i vlo = __lsx_vilvl_b(zero, a);
-    __m128i vhi = __lsx_vilvh_b(zero, a);
-    return lasx_set_q(vhi, vlo);
+    return __lasx_vext2xv_hu_bu(____m256i(a));
 }

 static __m256i lasx_ext8_16(__m128i a) {
-
-    __m128i vlo = __lsx_vilvl_b(sign, a);
-    __m128i vhi = __lsx_vilvh_b(sign, a);
-    return lasx_set_q(vhi, vlo);
+    return __lasx_vext2xv_h_b(____m256i(a));
 }

 static __m256i lasx_ext16_32(__m128i a) {
-
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
-    return tmp1;
+    return __lasx_vext2xv_w_h(____m256i(a));
 }

 static __m128i lasx_extracti128( __m256i a, int pos) {
@@ -482,25 +534,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
     return ret;
 }

-static __m128i lsx_hadd_h(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_h(b, a);
-    __m128i tmp2 = __lsx_vpickod_h(b, a);
-    return __lsx_vadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_hadd_w(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_w(b, a);
-    __m128i tmp2 = __lsx_vpickod_w(b, a);
-    return __lsx_vadd_w(tmp1, tmp2);
-}
-
-static __m128 lsx_hadd_s(__m128 a, __m128 b) {
-    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
-    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
-
-    return __lsx_vfadd_s(tmp1, tmp2);
-}
-
 static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
     __m256i tmp1, tmp2;
     tmp1 = __lasx_xvmulwev_h_b(a, b);
@@ -529,40 +562,39 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
     return __lasx_xvpickev_b(tmp1, tmp);
 }

-static
-
-
-
-    return
-}
-
-static __m128i lsx_packs_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_h(a, 7);
-    tmp1 = __lsx_vsat_h(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
-}
-
-static __m128i lsx_packus_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_hu(a, 7);
-    tmp1 = __lsx_vsat_hu(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
+static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
+    __m256i tmp1, tmp2;
+    tmp1 = __lasx_xvmulwev_h_b(a, b);
+    tmp2 = __lasx_xvmulwod_h_b(a, b);
+    return __lasx_xvadd_h(tmp1, tmp2);
 }

-
-
-
-
-
-
+static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
+    switch (b) {
+        case 0: return __lasx_xvrepl128vei_h(a, 0);
+        case 1: return __lasx_xvrepl128vei_h(a, 1);
+        case 2: return __lasx_xvrepl128vei_h(a, 2);
+        case 3: return __lasx_xvrepl128vei_h(a, 3);
+        case 4: return __lasx_xvrepl128vei_h(a, 4);
+        case 5: return __lasx_xvrepl128vei_h(a, 5);
+        case 6: return __lasx_xvrepl128vei_h(a, 6);
+        case 7: return __lasx_xvrepl128vei_h(a, 7);
+        default: __builtin_unreachable();
+    }
 }

-static
-
-
-
-
+static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
+    switch (b) {
+        case 0: return __lasx_xvandi_b(a, 1 << 0);
+        case 1: return __lasx_xvandi_b(a, 1 << 1);
+        case 2: return __lasx_xvandi_b(a, 1 << 2);
+        case 3: return __lasx_xvandi_b(a, 1 << 3);
+        case 4: return __lasx_xvandi_b(a, 1 << 4);
+        case 5: return __lasx_xvandi_b(a, 1 << 5);
+        case 6: return __lasx_xvandi_b(a, 1 << 6);
+        case 7: return __lasx_xvandi_b(a, 1 << 7);
+        default: __builtin_unreachable();
+    }
 }

 // multiply int8_t, add results pairwise twice
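The new lasx_madd_h_b helper above multiplies matching signed 8-bit lanes and adds each even/odd pair of products into one signed 16-bit lane. For reference only, a scalar sketch of that per-lane operation (names are hypothetical; this is not code from the package) looks like this:

#include <stdint.h>

// out16[i] = a8[2*i]*b8[2*i] + a8[2*i+1]*b8[2*i+1], mirroring lasx_madd_h_b
// on one 256-bit vector (32 input bytes -> 16 output halfwords). The cast back
// to int16_t mirrors the non-saturating 16-bit vector add used by the intrinsic.
static void madd_h_b_scalar(const int8_t a8[32], const int8_t b8[32], int16_t out16[16]) {
    for (int i = 0; i < 16; i++) {
        out16[i] = (int16_t)(a8[2*i] * b8[2*i] + a8[2*i + 1] * b8[2*i + 1]);
    }
}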
@@ -580,12 +612,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
 // horizontally add 8 floats
 static inline float hsum_float_8(const __m256 x) {
     __m128 res = lasx_extractf128(x, 1);
-    ft_union tmp;
     res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
     res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
     res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
-
-    return tmp.f;
+    return ((v4f32)res)[0];
 }

 // horizontally add 8 int32_t
@@ -661,13 +691,8 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy)

 // multiply int8_t, add results pairwise twice and return as float vector
 static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-
-
-    const __m256i ax = __lasx_xvsigncov_b(x, x);
-    // Sign the values of the y vectors
-    const __m256i sy = __lasx_xvsigncov_b(x, y);
-
-    return mul_sum_us8_pairs_float(ax, sy);
+    const __m256i dot = lasx_madd_h_b(x, y);
+    return sum_i16_pairs_float(dot);
 }

 static inline __m128i packNibbles( __m256i bytes ) {
@@ -747,7 +772,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
             y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
         }
     }
-#elif defined
+#elif defined __wasm_simd128__
     for (int i = 0; i < nb; i++) {
         v128_t srcv [8];
         v128_t asrcv[8];
@@ -927,7 +952,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)

 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
-        ft_union fi;
         __m256 v0 = (__m256)__lasx_xvld( x , 0);
         __m256 v1 = (__m256)__lasx_xvld( x , 32);
         __m256 v2 = (__m256)__lasx_xvld( x , 64);
@@ -945,8 +969,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
-
-        const float max_scalar = fi.f;
+        const float max_scalar = ((v4f32)max4)[0];

         // Quantize these floats
         const float d = max_scalar / 127.f;
@@ -1037,7 +1060,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)

         y[i].s = LM_GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
     }
-#elif defined
+#elif defined __wasm_simd128__
     for (int i = 0; i < nb; i++) {
         v128_t srcv [8];
         v128_t asrcv[8];
@@ -1251,7 +1274,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)

 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
-        ft_union ft;
         __m256 v0 = (__m256)__lasx_xvld( x , 0 );
         __m256 v1 = (__m256)__lasx_xvld( x , 32 );
         __m256 v2 = (__m256)__lasx_xvld( x , 64 );
@@ -1269,8 +1291,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
-
-        const float max_scalar = ft.f;
+        const float max_scalar = ((v4f32)max4)[0];

         // Quantize these floats
         const float d = max_scalar / 127.f;
@@ -1653,7 +1674,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1
 //===================================== Q8_K ==============================================

 void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
+#ifdef __wasm_simd128__
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+    block_q8_K * restrict yc = y; // Cast to proper type
+
+    for (int i = 0; i < nb; i++) {
+        const float * x_block = x + i * QK_K;
+
+        v128_t min_vec = wasm_v128_load(x_block);
+        v128_t max_vec = min_vec;
+
+        for (int j = 4; j < QK_K; j += 4) {
+            v128_t x_vec = wasm_v128_load(x_block + j);
+            max_vec = wasm_f32x4_pmax(max_vec, x_vec);
+            min_vec = wasm_f32x4_pmin(min_vec, x_vec);
+        }
+        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
+        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
+        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
+        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
+        float max = wasm_f32x4_extract_lane(max_vec, 0);
+        float min = wasm_f32x4_extract_lane(min_vec, 0);
+        float amax = -min > max ? min : max;
+
+        if (amax == 0.0f) {
+            yc[i].d = 0.0f;
+            const v128_t zero = wasm_i8x16_splat(0);
+            for (int j = 0; j < QK_K; j += 16) {
+                wasm_v128_store(yc[i].qs + j, zero);
+            }
+            continue;
+        }
+
+        const float iscale = -127.0f / amax;
+        const v128_t scale_vec = wasm_f32x4_splat(iscale);
+
+        // Process 16 elements per iteration
+        for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
+            // Load and quantize 16 floats
+            v128_t x0 = wasm_v128_load(x_block + j);
+            v128_t x1 = wasm_v128_load(x_block + j + 4);
+            v128_t x2 = wasm_v128_load(x_block + j + 8);
+            v128_t x3 = wasm_v128_load(x_block + j + 12);
+
+            v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
+            v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
+            v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
+            v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
+
+            // Convert to i32 with saturation
+            v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
+            v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
+            v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
+            v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
+
+            // Pack into 16 i8 values
+            v128_t i8 = wasm_i8x16_narrow_i16x8(
+                wasm_i16x8_narrow_i32x4(i0, i1),
+                wasm_i16x8_narrow_i32x4(i2, i3)
+            );
+            wasm_v128_store(yc[i].qs + j, i8);
+
+            // Calculate bsums using SIMD
+            v128_t sum16 = wasm_i16x8_add(
+                wasm_i16x8_extend_low_i8x16(i8),
+                wasm_i16x8_extend_high_i8x16(i8)
+            );
+            v128_t sum32 = wasm_i32x4_add(
+                wasm_i32x4_extend_low_i16x8(sum16),
+                wasm_i32x4_extend_high_i16x8(sum16)
+            );
+            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
+            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
+            yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
+        }
+
+        yc[i].d = 1.0f / iscale;
+    }
+#else
     quantize_row_q8_K_ref(x, y, k);
+#endif
 }

 //===================================== Dot products =================================
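The WASM path added above follows the usual block recipe: find the value with the largest magnitude, derive iscale = -127/amax, round each input to an int8 quant, record per-16 partial sums in bsums, and store d = 1/iscale. A compact scalar sketch of that recipe (struct and names are hypothetical; the real block_q8_K layout lives in the package headers):

#include <math.h>
#include <stdint.h>
#include <string.h>

// One 256-element block, mirroring the fields the SIMD code above writes.
typedef struct {
    float   d;          // dequantization scale (1/iscale)
    int8_t  qs[256];    // quantized values
    int16_t bsums[16];  // sum of each group of 16 quants
} q8_K_block_ref;

static void quantize_q8_K_block_ref(const float *x, q8_K_block_ref *y) {
    float max = 0.0f, amax = 0.0f;
    for (int j = 0; j < 256; j++) {
        const float ax = fabsf(x[j]);
        if (ax > amax) { amax = ax; max = x[j]; }
    }
    if (amax == 0.0f) {                // all-zero block: zero scale and quants
        y->d = 0.0f;
        memset(y->qs, 0, sizeof(y->qs));
        memset(y->bsums, 0, sizeof(y->bsums));
        return;
    }
    const float iscale = -127.0f / max;
    for (int j = 0; j < 256; j++) {
        int v = (int)nearbyintf(iscale * x[j]);
        y->qs[j] = (int8_t)(v > 127 ? 127 : v < -127 ? -127 : v);
    }
    for (int jb = 0; jb < 16; jb++) {  // per-16 partial sums used by the K-quant dot products
        int sum = 0;
        for (int l = 0; l < 16; l++) sum += y->qs[16*jb + l];
        y->bsums[jb] = (int16_t)sum;
    }
    y->d = 1.0f / iscale;
}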
@@ -2011,6 +2112,94 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
     }

     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined __wasm_simd128__
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    const v128_t m4b = wasm_i8x16_splat(0x0F);
+    const v128_t s8b = wasm_i8x16_splat(0x8);
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q4_0 * restrict x0 = &x[ib];
+        const block_q4_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib];
+        const block_q8_0 * restrict y1 = &y[ib + 1];
+
+        // Load and process x0
+        v128_t v0_0 = wasm_v128_load(x0->qs);
+        v128_t v0_0l = wasm_v128_and(v0_0, m4b);
+        v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
+        v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
+        v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
+
+        // Load y0 vectors
+        v128_t y0_l = wasm_v128_load(y0->qs);
+        v128_t y0_h = wasm_v128_load(y0->qs + 16);
+
+        // Extend to i16x8 and compute dot products
+        v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
+        v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
+        v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
+        v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
+
+        v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
+        v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
+        v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
+        v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
+
+        v128_t dp0 = wasm_i32x4_add(
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx0l, dy0ll),
+                wasm_i32x4_dot_i16x8(dx0h, dy0lh)
+            ),
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
+                wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
+            )
+        );
+
+        // Load and process x1
+        v128_t v0_1 = wasm_v128_load(x1->qs);
+        v128_t v0_1l = wasm_v128_and(v0_1, m4b);
+        v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
+        v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
+        v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
+
+        // Load y1 vectors
+        v128_t y1_l = wasm_v128_load(y1->qs);
+        v128_t y1_h = wasm_v128_load(y1->qs + 16);
+
+        // Extend to i16x8 and compute dot products
+        v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
+        v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
+        v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
+        v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
+
+        v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
+        v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
+        v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
+        v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
+
+        v128_t dp1 = wasm_i32x4_add(
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx1l, dy1ll),
+                wasm_i32x4_dot_i16x8(dx1h, dy1lh)
+            ),
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
+                wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
+            )
+        );
+
+        // Accumulate results with scaling
+        float scale0 = LM_GGML_FP16_TO_FP32(x0->d) * LM_GGML_FP16_TO_FP32(y0->d);
+        float scale1 = LM_GGML_FP16_TO_FP32(x1->d) * LM_GGML_FP16_TO_FP32(y1->d);
+
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
+    }
+
+    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2232,21 +2421,22 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
     }

     sumf = hsum_float_8(acc);
+
 #elif defined(__loongarch_sx)
     // set constants
     const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
     const __m128i off = __lsx_vreplgr2vr_b(8);

     // Initialize accumulator with zeros
-    __m128 acc_0 = __lsx_vldi(0);
-    __m128 acc_1 = __lsx_vldi(0);
-    __m128 acc_2 = __lsx_vldi(0);
-    __m128 acc_3 = __lsx_vldi(0);
+    __m128 acc_0 = (__m128)__lsx_vldi(0);
+    __m128 acc_1 = (__m128)__lsx_vldi(0);
+    __m128 acc_2 = (__m128)__lsx_vldi(0);
+    __m128 acc_3 = (__m128)__lsx_vldi(0);

     for (; ib + 1 < nb; ib += 2) {

         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) );
+        const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) );

         const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);

@@ -2264,7 +2454,7 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
         //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);

         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) );
+        const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) );

         const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);

@@ -2696,10 +2886,10 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void
     }

     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined
+#elif defined __wasm_simd128__
     v128_t sumv = wasm_f32x4_splat(0.0f);

-    uint32_t
+    uint32_t qh_;
     uint64_t tmp[4];

     // TODO: check if unrolling this is better
@@ -2710,12 +2900,12 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void
         const v128_t m4b = wasm_i8x16_splat(0x0F);

         // extract the 5th bit
-        memcpy(&
+        memcpy(&qh_, x0->qh, sizeof(qh_));

-        tmp[0] = table_b2b_1[(
-        tmp[1] = table_b2b_1[(
-        tmp[2] = table_b2b_1[(
-        tmp[3] = table_b2b_1[(
+        tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh_ >> 24) ];

         const v128_t qhl = wasm_v128_load(tmp + 0);
         const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3057,12 +3247,12 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void
     }

     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
-#elif defined
+#elif defined __wasm_simd128__
     v128_t sumv = wasm_f32x4_splat(0.0f);

     float summs = 0.0f;

-    uint32_t
+    uint32_t qh_;
     uint64_t tmp[4];

     // TODO: check if unrolling this is better
@@ -3075,12 +3265,12 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void
         const v128_t m4b = wasm_i8x16_splat(0x0F);

         // extract the 5th bit
-        memcpy(&
+        memcpy(&qh_, x0->qh, sizeof(qh_));

-        tmp[0] = table_b2b_0[(
-        tmp[1] = table_b2b_0[(
-        tmp[2] = table_b2b_0[(
-        tmp[3] = table_b2b_0[(
+        tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh_ >> 24) ];

         const v128_t qhl = wasm_v128_load(tmp + 0);
         const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3573,6 +3763,45 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
     }

     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined __wasm_simd128__
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    for (; ib < nb; ++ib) {
+        const block_q8_0 * restrict x0 = &x[ib];
+        const block_q8_0 * restrict y0 = &y[ib];
+
+        const v128_t x0_0 = wasm_v128_load(x0->qs);
+        const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
+        const v128_t y0_0 = wasm_v128_load(y0->qs);
+        const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
+
+        // Extend 8-bit to 16-bit
+        const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
+        const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
+        const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
+        const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
+
+        const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
+        const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
+        const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
+        const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
+
+        // Compute dot products
+        const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
+        const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
+        const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
+        const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
+
+        // Sum all dot products
+        const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
+
+        // Convert to float and accumulate
+        const float scale = LM_GGML_FP16_TO_FP32(x0->d) * LM_GGML_FP16_TO_FP32(y0->d);
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
+    }
+
+    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
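Behind the __wasm_simd128__ branch added above, the per-block math is simply an integer dot product of two 32-entry int8 blocks scaled by both block scales. A scalar sketch (the block type here is a hypothetical stand-in; the real one and the FP16 helpers come from the package headers):

#include <stdint.h>

// Minimal stand-in for a Q8_0 block: an FP16 scale plus 32 signed 8-bit quants.
typedef struct {
    uint16_t d;       // block scale, stored as FP16
    int8_t   qs[32];
} q8_0_block_ref;

// fp16_to_fp32 is passed in so the sketch stays self-contained.
static float dot_q8_0_ref(const q8_0_block_ref *x, const q8_0_block_ref *y, int nb,
                          float (*fp16_to_fp32)(uint16_t)) {
    float sumf = 0.0f;
    for (int ib = 0; ib < nb; ib++) {
        int sumi = 0;
        for (int j = 0; j < 32; j++) {
            sumi += (int)x[ib].qs[j] * (int)y[ib].qs[j];   // integer dot product of the block
        }
        sumf += fp16_to_fp32(x[ib].d) * fp16_to_fp32(y[ib].d) * (float)sumi;
    }
    return sumf;
}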
@@ -4447,6 +4676,106 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void

     *s = hsum_float_8(acc);

+#elif defined __wasm_simd128__
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * q2 = x[i].qs;
+        const int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        // Vectorized summs calculation
+        v128_t summs_vec = wasm_i32x4_splat(0);
+        {
+            v128_t sc_vec = wasm_v128_load(sc);
+            v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
+
+            v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
+            v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
+
+            v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
+            v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
+
+            summs_vec = wasm_i32x4_add(
+                wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
+                               wasm_i32x4_dot_i16x8(sc_high, bsums2)),
+                summs_vec
+            );
+
+            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
+            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
+        }
+        int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
+
+        // Vectorized isum calculation
+        int32_t isum = 0;
+        const uint8_t * sc_ptr = sc;
+        const int k_iters = QK_K/128;
+
+        for (int k = 0; k < k_iters; ++k) {
+            v128_t isum_vec = wasm_i32x4_splat(0);
+            int shift = 0;
+
+            for (int j = 0; j < 4; ++j) {
+                const int d0 = (sc_ptr[0] & 0xF);
+                const int d1 = (sc_ptr[1] & 0xF);
+                sc_ptr += 2;
+
+                // Process first 16 elements
+                v128_t q2_0 = wasm_v128_load(q2);
+                v128_t q8_0 = wasm_v128_load(q8);
+                v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
+                v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
+
+                // Process next 16 elements
+                v128_t q2_1 = wasm_v128_load(q2 + 16);
+                v128_t q8_1 = wasm_v128_load(q8 + 16);
+                v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
+                v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
+
+                // Calculate dot products
+                v128_t p0 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_low_i8x16(q8_0),
+                    wasm_i16x8_extend_low_i8x16(q2_bits_0)
+                );
+                v128_t p1 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_high_i8x16(q8_0),
+                    wasm_i16x8_extend_high_i8x16(q2_bits_0)
+                );
+                v128_t p2 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_low_i8x16(q8_1),
+                    wasm_i16x8_extend_low_i8x16(q2_bits_1)
+                );
+                v128_t p3 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_high_i8x16(q8_1),
+                    wasm_i16x8_extend_high_i8x16(q2_bits_1)
+                );
+
+                // Accumulate scaled results
+                v128_t scaled = wasm_i32x4_add(
+                    wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
+                    wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
+                );
+
+                isum_vec = wasm_i32x4_add(isum_vec, scaled);
+                q8 += 32;
+                shift += 2;
+            }
+            q2 += 32;
+
+            // Horizontal sum of isum_vec
+            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
+            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
+            isum += wasm_i32x4_extract_lane(isum_vec, 0);
+        }
+
+        const float dall = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const float dmin = LM_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf += dall * isum - dmin * summs;
+    }
+
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic

     float sumf = 0;
@@ -4666,9 +4995,6 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void

 #elif defined __loongarch_asx

-    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
-    const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
-
     __m256 acc = (__m256)__lasx_xvldi(0);

     for (int i = 0; i < nb; ++i) {
@@ -4679,18 +5005,15 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
         const uint8_t * restrict q2 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;

-        const __m128i
-        const __m128i
-        const
-        const __m256i mins = lasx_ext8_16(mins8);
+        const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
+        const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
         const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));

         acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);

-        const
-        const
-        const __m128i h_scales = lasx_extracti128(all_scales, 1);
-        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));

         __m256i sumi = __lasx_xvldi(0);

@@ -4703,20 +5026,20 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;

-            const __m256i q2_0 =
-            const __m256i q2_1 =
-            const __m256i q2_2 =
-            const __m256i q2_3 =
+            const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
+            const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
+            const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
+            const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);

-            __m256i p0 =
-            __m256i p1 =
-            __m256i p2 =
-            __m256i p3 =
+            __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
+            __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
+            __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
+            __m256i p3 = lasx_madd_h_b(q2_3, q8_3);

-            p0 = lasx_madd_h(
-            p1 = lasx_madd_h(
-            p2 = lasx_madd_h(
-            p3 = lasx_madd_h(
+            p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
+            p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
+            p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
+            p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);

             p0 = __lasx_xvadd_w(p0, p1);
             p2 = __lasx_xvadd_w(p2, p3);
@@ -5129,6 +5452,94 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void

     *s = hsum_float_8(acc);

+#elif defined __wasm_simd128__
+    int8_t aux8[QK_K];
+    float sums[8] = {0};
+    uint32_t auxs[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const int8_t * restrict q8 = y[i].qs;
+
+        // Process blocks with SIMD
+        int8_t * a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int shift = 0; shift <= 6; shift += 2) {
+                v128_t v_m = wasm_i8x16_splat(m);
+                for (int l = 0; l < 32; l += 16) {
+                    v128_t v_q3 = wasm_v128_load(q3 + l);
+                    v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
+                    v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
+
+                    v128_t v_hm = wasm_v128_load(hm + l);
+                    v128_t v_mask = wasm_v128_and(v_hm, v_m);
+                    v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
+
+                    v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
+                    wasm_v128_store(a + l, v_low2);
+                }
+                a += 32;
+                m <<= 1;
+            }
+            q3 += 32;
+        }
+
+        // Extract scales
+        memcpy(auxs, x[i].scales, 12);
+        uint32_t tmp = auxs[2];
+        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+        const int8_t * scales = (const int8_t *)auxs;
+
+        // SIMD dot product with register accumulators
+        v128_t v_acc0 = wasm_i32x4_splat(0);
+        v128_t v_acc1 = wasm_i32x4_splat(0);
+        a = aux8;
+        for (int j = 0; j < QK_K/16; ++j) {
+            const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
+
+            // Process 16 elements per iteration
+            for (int k = 0; k < 2; ++k) {
+                const v128_t v_q8 = wasm_i16x8_load8x8(q8);
+                const v128_t v_a = wasm_i16x8_load8x8(a);
+
+                v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
+                v_prod = wasm_i16x8_mul(v_prod, v_scale);
+
+                v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
+                v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
+
+                q8 += 8;
+                a += 8;
+            }
+        }
+
+        // Accumulate results
+        const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const v128_t v_d = wasm_f32x4_splat(d);
+        v128_t v_sum = wasm_f32x4_add(
+            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
+            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
+        );
+
+        // Accumulate into sums vector
+        wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
+    }
+
+    // Horizontal sum
+    v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
+    sumf = wasm_f32x4_extract_lane(v_sum, 0) +
+           wasm_f32x4_extract_lane(v_sum, 1) +
+           wasm_f32x4_extract_lane(v_sum, 2) +
+           wasm_f32x4_extract_lane(v_sum, 3);
+
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic

     uint32_t aux[3];
@@ -5384,8 +5795,6 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
5384
5795
|
|
5385
5796
|
#elif defined __loongarch_asx
|
5386
5797
|
|
5387
|
-
const __m256i m3 = __lasx_xvreplgr2vr_b(3);
|
5388
|
-
const __m256i mone = __lasx_xvreplgr2vr_b(1);
|
5389
5798
|
const __m128i m32 = __lsx_vreplgr2vr_b(32);
|
5390
5799
|
|
5391
5800
|
__m256 acc = (__m256)__lasx_xvldi(0);
|
@@ -5405,10 +5814,9 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
5405
5814
|
(aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
|
5406
5815
|
(aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
|
5407
5816
|
scales128 = __lsx_vsub_b(scales128, m32);
|
5408
|
-
|
5409
|
-
const
|
5410
|
-
const
|
5411
|
-
const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
|
5817
|
+
|
5818
|
+
const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
|
5819
|
+
const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
|
5412
5820
|
|
5413
5821
|
// high bit
|
5414
5822
|
const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
|
@@ -5416,35 +5824,23 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
5416
5824
|
// integer accumulator
|
5417
5825
|
__m256i sumi = __lasx_xvldi(0);
|
5418
5826
|
|
5419
|
-
int bit = 0;
|
5420
|
-
int is = 0;
|
5421
|
-
__m256i xvbit;
|
5422
|
-
|
5423
|
-
|
5424
5827
|
for (int j = 0; j < QK_K/128; ++j) {
|
5425
5828
|
// load low 2 bits
|
5426
5829
|
const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
|
5427
5830
|
|
5428
|
-
xvbit = __lasx_xvreplgr2vr_h(bit);
|
5429
5831
|
// prepare low and high bits
|
5430
|
-
const __m256i q3l_0 =
|
5431
|
-
const __m256i
|
5432
|
-
|
5433
|
-
|
5434
|
-
|
5435
|
-
const __m256i
|
5436
|
-
const __m256i
|
5437
|
-
|
5438
|
-
|
5439
|
-
|
5440
|
-
const __m256i
|
5441
|
-
const __m256i
|
5442
|
-
++bit;
|
5443
|
-
|
5444
|
-
xvbit = __lasx_xvreplgr2vr_h(bit);
|
5445
|
-
const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
|
5446
|
-
const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
|
5447
|
-
++bit;
|
5832
|
+
const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
|
5833
|
+
const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
|
5834
|
+
const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
|
5835
|
+
const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
|
5836
|
+
const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
|
5837
|
+
const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
|
5838
|
+
const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
|
5839
|
+
const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
|
5840
|
+
const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
|
5841
|
+
const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
|
5842
|
+
const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
|
5843
|
+
const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
|
5448
5844
|
|
5449
5845
|
// load Q8 quants
|
5450
5846
|
const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
@@ -5452,29 +5848,16 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-
-
-
-            __m256i
-            __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
-            __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
-            __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
-
-            __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
-            __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
-            __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
-
-            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
-            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
-            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+            __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
+            __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
+            __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
+            __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
 
             // multiply with scales
-            p16_0 = lasx_madd_h(
-            p16_1 = lasx_madd_h(
-            p16_2 = lasx_madd_h(
-            p16_3 = lasx_madd_h(
+            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
+            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
+            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
+            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
 
             // accumulate
             p16_0 = __lasx_xvadd_w(p16_0, p16_1);
@@ -5482,7 +5865,7 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
         }
         // multiply with block scale and accumulate
-        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc)
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
     }
 
     *s = hsum_float_8(acc);
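As an illustrative aside, the reworked q3_K loop above builds each 3-bit quant from the 2-bit `qs` plane plus one bit of `hmask` (bit index `4 * j + k` in the vector code), using the fact that OR-ing 0xFC (that is, -4 as a signed byte) into the low 2 bits is the same as subtracting 4 when the mask bit is clear. A scalar sketch of that unpacking, with a hypothetical helper name that is not part of the package:

    #include <stdint.h>

    // Hypothetical scalar view of one q3_K value: 2 low bits from the qs
    // plane, one bit from the hmask plane. A cleared hmask bit means
    // "subtract 4", mirroring the 0xFC OR trick in the LASX code.
    static inline int8_t q3k_unpack_one(uint8_t qs_byte, uint8_t hmask_byte,
                                        int chunk /* 0..3 */, int hbit /* 0..7 */) {
        const int low  = (qs_byte >> (2 * chunk)) & 3;
        const int high = (hmask_byte >> hbit) & 1;
        return (int8_t)(low - (high ? 0 : 4));
    }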
@@ -5654,7 +6037,7 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
         }
     }
     *s = sumf;
-#elif __ARM_NEON
+#elif defined __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);
 
@@ -5717,6 +6100,107 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
 
     *s = sumf;
 
+#elif defined __wasm_simd128__
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        // Process scales and mins
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        // Sum mins * q8sums
+        int32_t sumi = 0;
+        const int16_t * restrict q8sums = y[i].bsums;
+        const uint8_t * m = (const uint8_t *)&utmp[2];
+        for (int j = 0; j < 16; j += 2) {
+            sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
+        }
+        sumf -= dmin * sumi;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // Load 64 4-bit weights (32 bytes)
+            const v128_t q4x0 = wasm_v128_load(q4);
+            const v128_t q4x1 = wasm_v128_load(q4 + 16);
+            q4 += 32;
+
+            // Split into low/high nibbles
+            const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
+            const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
+            const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
+            const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
+
+            // Load 64 8-bit values (64 bytes)
+            const v128_t q8x0 = wasm_v128_load(q8);
+            const v128_t q8x1 = wasm_v128_load(q8 + 16);
+            const v128_t q8x2 = wasm_v128_load(q8 + 32);
+            const v128_t q8x3 = wasm_v128_load(q8 + 48);
+            q8 += 64;
+
+            // Low nibble products
+            v128_t vacc1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4l0),
+                wasm_i16x8_extend_low_i8x16(q8x0)
+            );
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4l0),
+                wasm_i16x8_extend_high_i8x16(q8x0)
+            ));
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4l1),
+                wasm_i16x8_extend_low_i8x16(q8x1)
+            ));
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4l1),
+                wasm_i16x8_extend_high_i8x16(q8x1)
+            ));
+
+            // High nibble products
+            v128_t vacc2 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4h0),
+                wasm_i16x8_extend_low_i8x16(q8x2)
+            );
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4h0),
+                wasm_i16x8_extend_high_i8x16(q8x2)
+            ));
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4h1),
+                wasm_i16x8_extend_low_i8x16(q8x3)
+            ));
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4h1),
+                wasm_i16x8_extend_high_i8x16(q8x3)
+            ));
+
+            // Accumulate scaled results
+            int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
+                                wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
+            sumi1 += vacc1_sum * scales[2*j];
+
+            int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
+                                wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
+            sumi2 += vacc2_sum * scales[2*j+1];
+        }
+
+        sumf += d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
+
 #elif defined __AVX2__
 
     const __m256i m4 = _mm256_set1_epi8(0xF);
@@ -6074,11 +6558,6 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    LM_GGML_UNUSED(kmask1);
-    LM_GGML_UNUSED(kmask2);
-    LM_GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
     __m128 acc_m = (__m128)__lsx_vldi(0);
@@ -6098,33 +6577,34 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;
 
-        const
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(
+        const __m128i prod = lsx_madd_h(mins128, q8s);
         acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         __m256i sumi = __lasx_xvldi(0);
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_l =
-            const __m256i scale_h =
+            const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
-            const __m256i q4l =
-            const __m256i q4h =
+            const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
+            const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
 
            const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16l =
+            __m256i p16l = lasx_madd_h_b(q4l, q8l);
            p16l = lasx_madd_h(scale_l, p16l);
 
            const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16h =
+            __m256i p16h = lasx_madd_h_b(q4h, q8h);
            p16h = lasx_madd_h(scale_h, p16h);
            const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
 
@@ -6141,9 +6621,7 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
     acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
 
 
-
-    fi.i = __lsx_vpickve2gr_w(acc_m, 0);
-    *s = hsum_float_8(acc) + fi.f ;
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
 #else
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
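As an illustrative aside, both the new __wasm_simd128__ path and the reworked LASX path for q4_K compute the same super-block sum: nibble-split the 4-bit quants, dot each 32-value sub-block against q8, weight by that sub-block's scale, then subtract the mins * bsums correction. A hedged scalar restatement under the assumption QK_K == 256, with made-up names and already-unpacked scales/mins:

    #include <stdint.h>

    // Hypothetical scalar reference for one q4_K super-block; d and dmin are
    // assumed to already include the q8 block scale y[i].d.
    static float q4k_superblock_dot_ref(const uint8_t *q4, const int8_t *q8,
                                        const uint8_t *scales, const uint8_t *mins,
                                        const int16_t *bsums, float d, float dmin) {
        int32_t sumi = 0;
        for (int j = 0; j < 256/64; ++j) {
            int32_t lo = 0, hi = 0;
            for (int l = 0; l < 32; ++l) {
                lo += (q4[32*j + l] & 0xF) * q8[64*j + l];        // low nibbles
                hi += (q4[32*j + l] >>  4) * q8[64*j + l + 32];   // high nibbles
            }
            sumi += lo * scales[2*j] + hi * scales[2*j + 1];
        }
        int32_t sumi_mins = 0;
        for (int j = 0; j < 8; ++j) {                             // 8 sub-blocks of 32
            sumi_mins += (bsums[2*j] + bsums[2*j + 1]) * mins[j];
        }
        return d * sumi - dmin * sumi_mins;
    }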
@@ -6469,6 +6947,118 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
 
     *s = hsum_float_8(acc) + summs;
 
+#elif defined __wasm_simd128__
+    //const uint8_t * scales = (const uint8_t*)&utmp[0];
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t * restrict q8 = y[i].qs;
+
+        // Process scales and mins
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        // Sum mins * q8sums
+        int32_t sumi_mins = 0;
+        const int16_t * restrict q8sums = y[i].bsums;
+        const uint8_t * m = (const uint8_t *)&utmp[2];
+        for (int j = 0; j < 16; j += 2) {
+            sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
+        }
+        sumf -= dmin * sumi_mins; // Correct subtraction
+
+        v128_t qh0 = wasm_v128_load(qh);
+        v128_t qh1 = wasm_v128_load(qh + 16);
+        const uint8_t * sc = (const uint8_t *)utmp;
+
+        int32_t sumi = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            const int shift = j * 2;
+            v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
+            v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
+
+            v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
+            v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
+            v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
+            v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
+
+            v128_t q5_0 = wasm_v128_load(q5);
+            v128_t q5_1 = wasm_v128_load(q5 + 16);
+            q5 += 32;
+
+            v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
+            v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
+            v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
+            v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
+
+            v128_t q8_0 = wasm_v128_load(q8);
+            v128_t q8_1 = wasm_v128_load(q8 + 16);
+            v128_t q8_2 = wasm_v128_load(q8 + 32);
+            v128_t q8_3 = wasm_v128_load(q8 + 48);
+            q8 += 64;
+
+            // Process low quants
+            v128_t pl0 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5l_0),
+                wasm_i16x8_extend_low_i8x16(q8_0)
+            );
+            pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5l_0),
+                wasm_i16x8_extend_high_i8x16(q8_0)
+            ));
+            v128_t pl1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5l_1),
+                wasm_i16x8_extend_low_i8x16(q8_1)
+            );
+            pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5l_1),
+                wasm_i16x8_extend_high_i8x16(q8_1)
+            ));
+            v128_t sum_low = wasm_i32x4_add(pl0, pl1);
+
+            // Process high quants
+            v128_t ph0 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5h_0),
+                wasm_i16x8_extend_low_i8x16(q8_2)
+            );
+            ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5h_0),
+                wasm_i16x8_extend_high_i8x16(q8_2)
+            ));
+            v128_t ph1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5h_1),
+                wasm_i16x8_extend_low_i8x16(q8_3)
+            );
+            ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5h_1),
+                wasm_i16x8_extend_high_i8x16(q8_3)
+            ));
+            v128_t sum_high = wasm_i32x4_add(ph0, ph1);
+
+            // Accumulate with scale factors
+            int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
+                         wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
+            int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
+                         wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
+
+            sumi += sl * sc[2*j] + sh * sc[2*j+1];
+        }
+
+        sumf += d * sumi;
+    }
+
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
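As an illustrative aside, the WASM q5_K block above reassembles each 5-bit quant by pulling bit 2*j of every qh byte into bit 4 of the low-nibble value and bit 2*j+1 into bit 4 of the high-nibble value. A scalar one-liner equivalent (hypothetical helper, not code from the package):

    #include <stdint.h>

    // Hypothetical scalar equivalent of the qh handling in the q5_K loop:
    // one bit of x[i].qh supplies bit 4 (value 16) of the 5-bit quant.
    static inline uint8_t q5k_unpack_one(uint8_t nibble /* 0..15 */,
                                         uint8_t qh_byte, int bit /* 0..7 */) {
        return (uint8_t)(nibble | (((qh_byte >> bit) & 1) << 4));   // 5-bit value 0..31
    }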
@@ -6691,19 +7281,11 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    LM_GGML_UNUSED(kmask1);
-    LM_GGML_UNUSED(kmask2);
-    LM_GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m128i mzero = __lsx_vldi(0);
-    const __m256i mone = __lasx_xvreplgr2vr_b(1);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
+    __m128 acc_m = (__m128)__lsx_vldi(0);
 
-
-
-    for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; ++i) {
 
         const uint8_t * restrict q5 = x[i].qs;
         const int8_t * restrict q8 = y[i].qs;
@@ -6718,49 +7300,40 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-        const
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(
-
-        summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check
+        const __m128i prod = lsx_madd_h(mins128, q8s);
+        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
-        __m256i hmask = mone;
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int bit = 0;
-        __m256i xvbit;
-
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_0 =
-            const __m256i scale_1 =
+            const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
-
-            const __m256i
-            const __m256i q5h_0 =
-            const __m256i
-
-
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
-            hmask = __lasx_xvslli_h(hmask, 1);
+            const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
+            const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
+            const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
+            const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
+            const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0);
+            const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i p16_0 =
-            __m256i p16_1 =
+            __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
+            __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
 
             p16_0 = lasx_madd_h(scale_0, p16_0);
             p16_1 = lasx_madd_h(scale_1, p16_1);
@@ -6774,7 +7347,10 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
 
     }
 
-
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
+
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
 
 #else
 
@@ -7132,6 +7708,85 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
 
     *s = hsum_float_8(acc);
 
+#elif defined __wasm_simd128__
+    int8_t aux8[QK_K] __attribute__((aligned(16)));
+    int32_t aux32[8] __attribute__((aligned(16))) = {0};
+    float sums[8] __attribute__((aligned(16))) = {0};
+
+    for (int i = 0; i < nb; ++i) {
+        // Unpack 6-bit quantized data into aux8 (unchanged)
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        int8_t * a = aux8;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                a[l + 64] = (int8_t)((q4[l +  0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            }
+            a += 128;
+            q4 += 64;
+            qh += 32;
+        }
+
+        const int8_t * restrict a_ptr = aux8;
+        const int8_t * restrict q8 = y[i].qs;
+        v128_t acc0 = wasm_i32x4_splat(0);
+        v128_t acc1 = wasm_i32x4_splat(0);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const int scale = x[i].scales[j];
+            const v128_t vscale = wasm_i32x4_splat(scale);
+
+            // Load 16 elements from a and q8
+            const v128_t a_vec = wasm_v128_load(a_ptr);
+            const v128_t q8_vec = wasm_v128_load(q8);
+
+            // Process low 8 elements
+            v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
+            v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
+            v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
+            v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
+            v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
+
+            // Process high 8 elements
+            v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
+            v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
+            v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
+            v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
+            v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
+
+            // Scale and accumulate
+            prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
+            prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
+            prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
+            prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
+
+            acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
+            acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
+
+            a_ptr += 16;
+            q8 += 16;
+        }
+
+        // Store accumulated results
+        wasm_v128_store(&aux32[0], acc0);
+        wasm_v128_store(&aux32[4], acc1);
+
+        const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) {
+            sums[l] += d * aux32[l];
+        }
+    }
+
+    // Sum final results
+    float sumf = 0;
+    for (int l = 0; l < 8; ++l) {
+        sumf += sums[l];
+    }
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     float sumf = 0;
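As an illustrative aside, the WASM q6_K block above first expands the 6-bit values to signed bytes (aux8) and then weights each run of 16 products by its signed 8-bit block scale. A hedged scalar counterpart of that accumulation, assuming QK_K == 256 and using a made-up helper name:

    #include <stdint.h>

    // Hypothetical scalar counterpart of the per-16 scale weighting in the
    // WASM q6_K path; aux8 holds the already-unpacked signed 6-bit values.
    static int32_t q6k_weighted_dot_ref(const int8_t *aux8, const int8_t *q8,
                                        const int8_t *scales) {
        int32_t sumi = 0;
        for (int j = 0; j < 256/16; ++j) {
            int32_t p = 0;
            for (int l = 0; l < 16; ++l) {
                p += aux8[16*j + l] * q8[16*j + l];
            }
            sumi += p * scales[j];
        }
        return sumi;
    }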
@@ -7356,8 +8011,6 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
 
 #elif defined __loongarch_asx
 
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m256i m2 = __lasx_xvreplgr2vr_b(3);
     const __m256i m32s = __lasx_xvreplgr2vr_b(32);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
@@ -7370,58 +8023,42 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
         const uint8_t * restrict qh = x[i].qh;
         const int8_t * restrict q8 = y[i].qs;
 
-        const __m128i
+        const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int is = 0;
-
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
-            const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
-            const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
-            const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
-            is += 4;
-
             const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
 
-            const __m256i q4h_0 =
-            const __m256i q4h_1 =
-            const __m256i q4h_2 =
-            const __m256i q4h_3 =
+            const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
+            const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
+            const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
+            const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
 
-            const __m256i q4_0 = __lasx_xvor_v(
-            const __m256i q4_1 = __lasx_xvor_v(
-            const __m256i q4_2 = __lasx_xvor_v(
-            const __m256i q4_3 = __lasx_xvor_v(
+            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
+            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
+            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
+            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i
-            __m256i
-            __m256i
-            __m256i
+            __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
+            __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
+            __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
+            __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
 
-
-
-
-
-
-            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
-            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
-            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
-
-            p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
-            p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
-            p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
-            p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
+            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
+            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
+            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
+            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
 
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
@@ -9746,13 +10383,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
 }
 #elif defined(__loongarch_asx)
 static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i
-    const __m256i
-
-    tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
-    tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
-    tmp3 = __lasx_xvadd_h(tmp1, tmp2);
-    return __lasx_xvsat_h(tmp3, 15);
+    const __m256i a = __lasx_xvmulwev_h_b(x, y);
+    const __m256i b = __lasx_xvmulwod_h_b(x, y);
+    return __lasx_xvadd_h(a, b);
 }
 #endif
 
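As an illustrative aside, the simplified LASX mul_add_epi8 above multiplies even and odd signed bytes into 16-bit lanes and adds the pairs, where the removed version widened one operand as unsigned and saturated the sum. A hedged per-lane scalar reading of the new behaviour (hypothetical helper, not part of the package):

    #include <stdint.h>

    // Hypothetical scalar model of the reworked mul_add_epi8: each 16-bit
    // output lane is the sum of two adjacent signed 8-bit products.
    static void mul_add_epi8_ref(const int8_t x[32], const int8_t y[32], int16_t out[16]) {
        for (int k = 0; k < 16; ++k) {
            out[k] = (int16_t)(x[2*k] * y[2*k] + x[2*k + 1] * y[2*k + 1]);
        }
    }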
@@ -10802,67 +11435,31 @@ void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const voi
 #elif defined(__loongarch_asx)
 
     const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
-    const __m128i m4b = __lsx_vreplgr2vr_b(0x0f);
 
     __m256 accum = (__m256)__lasx_xvldi(0);
-    __m256i tmp1;
-    __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
 
-    mask_8f = __lsx_vreplgr2vr_b(0x8f);
     for (int ibl = 0; ibl < nb; ++ibl) {
         const uint8_t * qs = x[ibl].qs;
         const int8_t * q8 = y[ibl].qs;
         uint16_t sh = x[ibl].scales_h;
         __m256i sumi1 = __lasx_xvldi(0);
         __m256i sumi2 = __lasx_xvldi(0);
-        __m128i zero = __lsx_vldi(0);
         for (int ib = 0; ib < QK_K/32; ib += 2) {
-            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0);
-            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0);
+            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
+            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
             const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
             const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
-
-
-
-
-            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp4 = __lsx_vand_v(tmp0, mask);
-            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
-
-            const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp3 = __lsx_vand_v(tmp0, mask);
-            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp4 = __lsx_vand_v(tmp0, mask);
-            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
-
-            const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
-
+            const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
+                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
+            const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
+                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
             const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
             const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
             const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
             const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
             sh >>= 4;
-            __m256i
-
-            tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
-            tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
-            const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
-            tmp1 = __lasx_xvreplgr2vr_h(ls2);
-            tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
-            tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
-            const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
+            const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
+            const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
             sumi1 = __lasx_xvadd_w(p_1, sumi1);
             sumi2 = __lasx_xvadd_w(p_2, sumi2);
         }
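As an illustrative aside, the simplified iq4_xs path above replaces the masked multi-step lookup with one __lsx_vshuf_b per nibble half: each 4-bit index selects a signed 8-bit entry from the 16-value table (kvalues_iq4nl in the source). A hedged scalar sketch with a made-up helper name:

    #include <stdint.h>

    // Hypothetical scalar model of the codebook lookup: 16 packed bytes expand
    // to 16 low-nibble and 16 high-nibble table values.
    static void iq4_lut_expand_16(const uint8_t qs[16], const int8_t kvalues[16],
                                  int8_t out_lo[16], int8_t out_hi[16]) {
        for (int k = 0; k < 16; ++k) {
            out_lo[k] = kvalues[qs[k] & 0xF];   // low nibbles
            out_hi[k] = kvalues[qs[k] >> 4];    // high nibbles
        }
    }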