cui-llama.rn 1.4.3 → 1.4.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. package/README.md +93 -114
  2. package/android/src/main/CMakeLists.txt +5 -0
  3. package/android/src/main/build-arm64/CMakeCache.txt +429 -0
  4. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +21 -21
  5. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +101 -0
  6. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  7. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  8. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +376 -0
  9. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +16 -0
  10. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +165 -0
  11. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +297 -0
  12. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +1 -0
  13. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +1 -0
  14. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +1 -0
  15. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +1 -0
  16. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +1 -0
  17. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +1 -0
  18. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +1 -0
  19. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +8 -0
  20. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +1 -0
  21. package/android/src/main/build-arm64/CMakeFiles/progress.marks +1 -0
  22. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  23. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +58 -0
  24. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  25. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +756 -0
  26. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  27. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +709 -0
  28. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  29. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +714 -0
  30. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  31. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +62 -0
  32. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  33. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +708 -0
  34. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  35. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +113 -0
  36. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  37. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +713 -0
  38. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  39. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +763 -0
  40. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  41. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +61 -0
  42. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  43. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +707 -0
  44. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  45. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +104 -0
  46. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  47. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +714 -0
  48. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  49. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +723 -0
  50. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +62 -0
  51. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +722 -0
  52. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +89 -0
  53. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +2 -0
  54. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +2 -0
  55. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +2 -0
  56. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +17 -0
  57. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +41 -0
  58. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +62 -0
  59. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +722 -0
  60. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +89 -0
  61. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +2 -0
  62. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +2 -0
  63. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +2 -0
  64. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +17 -0
  65. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +41 -0
  66. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +62 -0
  67. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +722 -0
  68. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +89 -0
  69. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +2 -0
  70. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +2 -0
  71. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +2 -0
  72. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +17 -0
  73. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +41 -0
  74. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +62 -0
  75. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +722 -0
  76. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +89 -0
  77. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +2 -0
  78. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +2 -0
  79. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +2 -0
  80. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +17 -0
  81. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +41 -0
  82. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +62 -0
  83. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +722 -0
  84. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +89 -0
  85. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +2 -0
  86. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +2 -0
  87. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +2 -0
  88. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +17 -0
  89. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +41 -0
  90. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +62 -0
  91. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +722 -0
  92. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +89 -0
  93. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +2 -0
  94. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +2 -0
  95. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +2 -0
  96. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +17 -0
  97. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +41 -0
  98. package/android/src/main/build-arm64/Makefile +1862 -0
  99. package/android/src/main/build-arm64/cmake_install.cmake +66 -0
  100. package/android/src/main/java/com/rnllama/LlamaContext.java +91 -17
  101. package/android/src/main/java/com/rnllama/RNLlama.java +37 -4
  102. package/android/src/main/jni-utils.h +6 -0
  103. package/android/src/main/jni.cpp +287 -31
  104. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  105. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  106. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  107. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  108. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  109. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  110. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  111. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  112. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +7 -2
  113. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +7 -2
  114. package/cpp/chat-template.hpp +529 -0
  115. package/cpp/chat.cpp +1085 -0
  116. package/cpp/chat.hpp +55 -0
  117. package/cpp/common.cpp +159 -36
  118. package/cpp/common.h +64 -19
  119. package/cpp/ggml-alloc.c +1 -13
  120. package/cpp/ggml-common.h +0 -2
  121. package/cpp/ggml-cpu-impl.h +6 -12
  122. package/cpp/ggml-cpu-quants.c +937 -340
  123. package/cpp/ggml-cpu.c +207 -113
  124. package/cpp/ggml-cpu.cpp +4 -6
  125. package/cpp/ggml-cpu.h +1 -1
  126. package/cpp/ggml-metal.h +66 -66
  127. package/cpp/ggml-metal.m +141 -23
  128. package/cpp/ggml.c +24 -14
  129. package/cpp/ggml.h +2 -2
  130. package/cpp/json-schema-to-grammar.cpp +46 -66
  131. package/cpp/json-schema-to-grammar.h +15 -1
  132. package/cpp/llama-arch.cpp +7 -2
  133. package/cpp/llama-arch.h +3 -1
  134. package/cpp/llama-chat.cpp +10 -1
  135. package/cpp/llama-chat.h +1 -0
  136. package/cpp/llama-grammar.cpp +86 -6
  137. package/cpp/llama-grammar.h +22 -1
  138. package/cpp/llama-impl.h +6 -6
  139. package/cpp/llama-kv-cache.h +1 -1
  140. package/cpp/llama-mmap.h +1 -0
  141. package/cpp/llama-model-loader.cpp +1 -1
  142. package/cpp/llama-model.cpp +32 -6
  143. package/cpp/llama-sampling.cpp +178 -61
  144. package/cpp/llama-vocab.cpp +8 -3
  145. package/cpp/llama.cpp +188 -128
  146. package/cpp/llama.h +27 -10
  147. package/cpp/log.cpp +32 -10
  148. package/cpp/log.h +12 -1
  149. package/cpp/minja.hpp +2883 -0
  150. package/cpp/rn-llama.cpp +82 -5
  151. package/cpp/rn-llama.h +16 -1
  152. package/cpp/sampling.cpp +68 -41
  153. package/cpp/sampling.h +3 -0
  154. package/cpp/sgemm.cpp +9 -8
  155. package/cpp/unicode.cpp +9 -2
  156. package/ios/CMakeLists.txt +6 -0
  157. package/ios/RNLlama.h +0 -8
  158. package/ios/RNLlama.mm +27 -3
  159. package/ios/RNLlamaContext.h +10 -1
  160. package/ios/RNLlamaContext.mm +269 -57
  161. package/jest/mock.js +21 -2
  162. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  163. package/lib/commonjs/grammar.js +3 -0
  164. package/lib/commonjs/grammar.js.map +1 -1
  165. package/lib/commonjs/index.js +87 -13
  166. package/lib/commonjs/index.js.map +1 -1
  167. package/lib/module/NativeRNLlama.js.map +1 -1
  168. package/lib/module/grammar.js +3 -0
  169. package/lib/module/grammar.js.map +1 -1
  170. package/lib/module/index.js +86 -13
  171. package/lib/module/index.js.map +1 -1
  172. package/lib/typescript/NativeRNLlama.d.ts +107 -2
  173. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  174. package/lib/typescript/grammar.d.ts.map +1 -1
  175. package/lib/typescript/index.d.ts +32 -7
  176. package/lib/typescript/index.d.ts.map +1 -1
  177. package/llama-rn.podspec +1 -1
  178. package/package.json +3 -2
  179. package/src/NativeRNLlama.ts +115 -3
  180. package/src/grammar.ts +3 -0
  181. package/src/index.ts +138 -21
@@ -297,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
  static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
  #endif
 
+ #if defined(__loongarch_sx)
+
+ static __m128i lsx_packs_w(__m128i a, __m128i b) {
+ __m128i tmp, tmp1;
+ tmp = __lsx_vsat_w(a, 15);
+ tmp1 = __lsx_vsat_w(b, 15);
+ return __lsx_vpickev_h(tmp1, tmp);
+ }
+
+ static __m128i lsx_packs_h(__m128i a, __m128i b) {
+ __m128i tmp, tmp1;
+ tmp = __lsx_vsat_h(a, 7);
+ tmp1 = __lsx_vsat_h(b, 7);
+ return __lsx_vpickev_b(tmp1, tmp);
+ }
+
+ static __m128i lsx_packus_h(__m128i a, __m128i b) {
+ __m128i tmp, tmp1;
+ tmp = __lsx_vsat_hu(a, 7);
+ tmp1 = __lsx_vsat_hu(b, 7);
+ return __lsx_vpickev_b(tmp1, tmp);
+ }
+
+ static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
+ __m128i tmp1, tmp2;
+ tmp1 = __lsx_vmulwev_h_b(a, b);
+ tmp2 = __lsx_vmulwod_h_b(a, b);
+ return __lsx_vsadd_h(tmp1, tmp2);
+ }
+
+ static __m128i lsx_madd_h(__m128i a, __m128i b) {
+ __m128i tmp1, tmp2;
+ tmp1 = __lsx_vmulwev_w_h(a, b);
+ tmp2 = __lsx_vmulwod_w_h(a, b);
+ return __lsx_vadd_w(tmp1, tmp2);
+ }
+
+ static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
+ v4i32 __ret = {d, c, b, a};
+ return (__m128i)__ret;
+ }
+
+ static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
+ __m128i mask_f, zero, tmp0, tmp2, mask;
+ int f = 0x8f;
+ mask_f = __lsx_vreplgr2vr_b(f);
+ zero = __lsx_vldi(0);
+ tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
+ tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
+ mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
+ tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
+ return __lsx_vshuf_b(a, zero, tmp2);
+ }
+
+ static __m128i lsx_hadd_h(__m128i a, __m128i b) {
+ __m128i tmp1 = __lsx_vpickev_h(b, a);
+ __m128i tmp2 = __lsx_vpickod_h(b, a);
+ return __lsx_vadd_h(tmp1, tmp2);
+ }
+
+ static __m128i lsx_hadd_w(__m128i a, __m128i b) {
+ __m128i tmp1 = __lsx_vpickev_w(b, a);
+ __m128i tmp2 = __lsx_vpickod_w(b, a);
+ return __lsx_vadd_w(tmp1, tmp2);
+ }
+
+ static __m128 lsx_hadd_s(__m128 a, __m128 b) {
+ __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
+ __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
+
+ return __lsx_vfadd_s(tmp1, tmp2);
+ }
+
+ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+ __m128 res_0 =lsx_hadd_s(a, b);
+ __m128 res_1 =lsx_hadd_s(c, d);
+ __m128 res =lsx_hadd_s(res_0, res_1);
+ res =lsx_hadd_s(res, res);
+ res =lsx_hadd_s(res, res);
+
+ return ((v4f32)res)[0];
+ }
+ #endif
+
  #if defined(__loongarch_asx)
 
  #ifdef __clang__
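Note on the helpers added in the hunk above: `lsx_packs_w`, `lsx_packs_h` and `lsx_packus_h` emulate an x86-style saturate-and-pack on LSX by saturating each lane and then picking the even-indexed narrow lanes. A minimal scalar sketch of that semantics (reference-only; the names below are illustrative and not part of the package):

```c
#include <stdint.h>

// Reference-only sketch of the saturate-and-pack step emulated by lsx_packs_h:
// every 16-bit lane is clamped to the int8 range, a's lanes fill the low half
// of the result and b's lanes fill the high half.
static int8_t sat_i16_to_i8(int16_t v) {
    if (v >  127) return  127;
    if (v < -128) return -128;
    return (int8_t)v;
}

static void packs_h_ref(const int16_t a[8], const int16_t b[8], int8_t out[16]) {
    for (int i = 0; i < 8; ++i) {
        out[i]     = sat_i16_to_i8(a[i]);
        out[i + 8] = sat_i16_to_i8(b[i]);
    }
}
```

`lsx_packs_w` and `lsx_packus_h` follow the same pattern at int32 width and with unsigned saturation, while `lsx_maddubs_h` and `lsx_madd_h` use the same even/odd-lane combination for pairwise widening multiply-adds.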
@@ -395,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
  return (__m256i)__ret;
  }
 
- static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
- v4i32 __ret = {d, c, b, a};
- return (__m128i)__ret;
- }
-
  static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
  v4i64 __ret = {d, c, b, a};
  return (__m256i)__ret;
@@ -409,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
  return lasx_set_q(x, y);
  }
 
- static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
- __m128i mask_f, zero, tmp0, tmp2, mask;
- int f = 0x8f;
- mask_f = __lsx_vreplgr2vr_b(f);
- zero = __lsx_vldi(0);
- tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
- tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
- mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
- tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
- return __lsx_vshuf_b(a, zero, tmp2);
- }
-
  static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
  __m256i mask_f, zero, tmp0, tmp2, mask;
  int f = 0x8f;
@@ -434,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
  }
 
  static __m256i lasx_extu8_16(__m128i a) {
- __m128i zero = __lsx_vldi(0);
- __m128i vlo = __lsx_vilvl_b(zero, a);
- __m128i vhi = __lsx_vilvh_b(zero, a);
- return lasx_set_q(vhi, vlo);
+ return __lasx_vext2xv_hu_bu(____m256i(a));
  }
 
  static __m256i lasx_ext8_16(__m128i a) {
- __m128i sign = __lsx_vslti_b(a, 0);
- __m128i vlo = __lsx_vilvl_b(sign, a);
- __m128i vhi = __lsx_vilvh_b(sign, a);
- return lasx_set_q(vhi, vlo);
+ return __lasx_vext2xv_h_b(____m256i(a));
  }
 
  static __m256i lasx_ext16_32(__m128i a) {
- __m256i tmp1;
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
- return tmp1;
+ return __lasx_vext2xv_w_h(____m256i(a));
  }
 
  static __m128i lasx_extracti128( __m256i a, int pos) {
@@ -482,25 +534,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
  return ret;
  }
 
- static __m128i lsx_hadd_h(__m128i a, __m128i b) {
- __m128i tmp1 = __lsx_vpickev_h(b, a);
- __m128i tmp2 = __lsx_vpickod_h(b, a);
- return __lsx_vadd_h(tmp1, tmp2);
- }
-
- static __m128i lsx_hadd_w(__m128i a, __m128i b) {
- __m128i tmp1 = __lsx_vpickev_w(b, a);
- __m128i tmp2 = __lsx_vpickod_w(b, a);
- return __lsx_vadd_w(tmp1, tmp2);
- }
-
- static __m128 lsx_hadd_s(__m128 a, __m128 b) {
- __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
- __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
-
- return __lsx_vfadd_s(tmp1, tmp2);
- }
-
  static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
  __m256i tmp1, tmp2;
  tmp1 = __lasx_xvmulwev_h_b(a, b);
@@ -529,40 +562,39 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
  return __lasx_xvpickev_b(tmp1, tmp);
  }
 
- static __m128i lsx_packs_w(__m128i a, __m128i b) {
- __m128i tmp, tmp1;
- tmp = __lsx_vsat_w(a, 15);
- tmp1 = __lsx_vsat_w(b, 15);
- return __lsx_vpickev_h(tmp1, tmp);
- }
-
- static __m128i lsx_packs_h(__m128i a, __m128i b) {
- __m128i tmp, tmp1;
- tmp = __lsx_vsat_h(a, 7);
- tmp1 = __lsx_vsat_h(b, 7);
- return __lsx_vpickev_b(tmp1, tmp);
- }
-
- static __m128i lsx_packus_h(__m128i a, __m128i b) {
- __m128i tmp, tmp1;
- tmp = __lsx_vsat_hu(a, 7);
- tmp1 = __lsx_vsat_hu(b, 7);
- return __lsx_vpickev_b(tmp1, tmp);
+ static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
+ __m256i tmp1, tmp2;
+ tmp1 = __lasx_xvmulwev_h_b(a, b);
+ tmp2 = __lasx_xvmulwod_h_b(a, b);
+ return __lasx_xvadd_h(tmp1, tmp2);
  }
 
-
- static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
- __m128i tmp1, tmp2;
- tmp1 = __lsx_vmulwev_h_b(a, b);
- tmp2 = __lsx_vmulwod_h_b(a, b);
- return __lsx_vsadd_h(tmp1, tmp2);
+ static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
+ switch (b) {
+ case 0: return __lasx_xvrepl128vei_h(a, 0);
+ case 1: return __lasx_xvrepl128vei_h(a, 1);
+ case 2: return __lasx_xvrepl128vei_h(a, 2);
+ case 3: return __lasx_xvrepl128vei_h(a, 3);
+ case 4: return __lasx_xvrepl128vei_h(a, 4);
+ case 5: return __lasx_xvrepl128vei_h(a, 5);
+ case 6: return __lasx_xvrepl128vei_h(a, 6);
+ case 7: return __lasx_xvrepl128vei_h(a, 7);
+ default: __builtin_unreachable();
+ }
  }
 
- static __m128i lsx_madd_h(__m128i a, __m128i b) {
- __m128i tmp1, tmp2;
- tmp1 = __lsx_vmulwev_w_h(a, b);
- tmp2 = __lsx_vmulwod_w_h(a, b);
- return __lsx_vadd_w(tmp1, tmp2);
+ static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
+ switch (b) {
+ case 0: return __lasx_xvandi_b(a, 1 << 0);
+ case 1: return __lasx_xvandi_b(a, 1 << 1);
+ case 2: return __lasx_xvandi_b(a, 1 << 2);
+ case 3: return __lasx_xvandi_b(a, 1 << 3);
+ case 4: return __lasx_xvandi_b(a, 1 << 4);
+ case 5: return __lasx_xvandi_b(a, 1 << 5);
+ case 6: return __lasx_xvandi_b(a, 1 << 6);
+ case 7: return __lasx_xvandi_b(a, 1 << 7);
+ default: __builtin_unreachable();
+ }
  }
 
  // multiply int8_t, add results pairwise twice
@@ -580,12 +612,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
  // horizontally add 8 floats
  static inline float hsum_float_8(const __m256 x) {
  __m128 res = lasx_extractf128(x, 1);
- ft_union tmp;
  res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
  res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
  res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
- tmp.i = __lsx_vpickve2gr_w(res, 0);
- return tmp.f;
+ return ((v4f32)res)[0];
  }
 
  // horizontally add 8 int32_t
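The `hsum_float_8` change above drops the `ft_union` integer/float punning and reads lane 0 of the reduced vector directly; the reduction itself is unchanged. In scalar terms it is simply the sum of eight float lanes (reference-only sketch, illustrative name):

```c
// Reference-only sketch of hsum_float_8: the SIMD version folds the upper
// 128-bit half onto the lower one and keeps halving until lane 0 holds the
// total, which is now read directly instead of through an int/float union.
static inline float hsum_float_8_ref(const float x[8]) {
    float sum = 0.0f;
    for (int i = 0; i < 8; ++i) {
        sum += x[i];
    }
    return sum;
}
```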
@@ -661,13 +691,8 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy)
 
  // multiply int8_t, add results pairwise twice and return as float vector
  static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-
- // Get absolute values of x vectors
- const __m256i ax = __lasx_xvsigncov_b(x, x);
- // Sign the values of the y vectors
- const __m256i sy = __lasx_xvsigncov_b(x, y);
-
- return mul_sum_us8_pairs_float(ax, sy);
+ const __m256i dot = lasx_madd_h_b(x, y);
+ return sum_i16_pairs_float(dot);
  }
 
  static inline __m128i packNibbles( __m256i bytes ) {
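`mul_sum_i8_pairs_float` now feeds the signed bytes straight into `lasx_madd_h_b` instead of first splitting them into an absolute-value vector and a sign-copied vector; both formulations accumulate the same per-lane products. A scalar sketch of the quantity being computed for one 256-bit register pair (reference-only; the name below is illustrative, 32 int8 lanes assumed):

```c
#include <stdint.h>

// Reference-only sketch of the integer core of mul_sum_i8_pairs_float: each of
// the 8 output floats is the sum of 4 adjacent signed-byte products, i.e. a
// pairwise i8 -> i16 multiply-add followed by a pairwise i16 -> i32 add.
static void mul_sum_i8_pairs_float_ref(const int8_t x[32], const int8_t y[32], float out[8]) {
    for (int j = 0; j < 8; ++j) {
        int32_t acc = 0;
        for (int i = 0; i < 4; ++i) {
            acc += (int32_t)x[4*j + i] * (int32_t)y[4*j + i];
        }
        out[j] = (float)acc;
    }
}
```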
@@ -747,7 +772,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
  y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
  }
  }
- #elif defined(__wasm_simd128__)
+ #elif defined __wasm_simd128__
  for (int i = 0; i < nb; i++) {
  v128_t srcv [8];
  v128_t asrcv[8];
@@ -927,7 +952,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
 
  #elif defined(__loongarch_asx)
  for (int i = 0; i < nb; i++) {
- ft_union fi;
  __m256 v0 = (__m256)__lasx_xvld( x , 0);
  __m256 v1 = (__m256)__lasx_xvld( x , 32);
  __m256 v2 = (__m256)__lasx_xvld( x , 64);
@@ -945,8 +969,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
  __m128 tmp = max4;
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
- fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
- const float max_scalar = fi.f;
+ const float max_scalar = ((v4f32)max4)[0];
 
  // Quantize these floats
  const float d = max_scalar / 127.f;
@@ -1037,7 +1060,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 
  y[i].s = LM_GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
  }
- #elif defined(__wasm_simd128__)
+ #elif defined __wasm_simd128__
  for (int i = 0; i < nb; i++) {
  v128_t srcv [8];
  v128_t asrcv[8];
@@ -1251,7 +1274,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 
  #elif defined(__loongarch_asx)
  for (int i = 0; i < nb; i++) {
- ft_union ft;
  __m256 v0 = (__m256)__lasx_xvld( x , 0 );
  __m256 v1 = (__m256)__lasx_xvld( x , 32 );
  __m256 v2 = (__m256)__lasx_xvld( x , 64 );
@@ -1269,8 +1291,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
  __m128 tmp = max4;
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
- ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
- const float max_scalar = ft.f;
+ const float max_scalar = ((v4f32)max4)[0];
 
  // Quantize these floats
  const float d = max_scalar / 127.f;
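The `quantize_row_q8_0` and `quantize_row_q8_1` hunks above only change how the block maximum is moved from the vector register into `max_scalar`; the quantization math itself is untouched. As a reminder of that math, a scalar sketch for one block (reference-only; `QK8`, the helper name, and the use of `roundf` are assumptions of this sketch):

```c
#include <math.h>
#include <stdint.h>

// Reference-only sketch of the per-block q8 scaling visible above: the largest
// absolute value maps to +/-127 and d is stored so that x ~ d * q.
enum { QK8 = 32 };

static void quantize_block_q8_ref(const float x[QK8], int8_t q[QK8], float *d_out) {
    float amax = 0.0f;
    for (int i = 0; i < QK8; ++i) {
        const float ax = fabsf(x[i]);
        if (ax > amax) amax = ax;
    }
    const float d  = amax / 127.0f;
    const float id = (d != 0.0f) ? 1.0f / d : 0.0f;
    for (int i = 0; i < QK8; ++i) {
        q[i] = (int8_t)roundf(x[i] * id);
    }
    *d_out = d;
}
```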
@@ -1653,7 +1674,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1
  //===================================== Q8_K ==============================================
 
  void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
+ #ifdef __wasm_simd128__
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+ block_q8_K * restrict yc = y; // Cast to proper type
+
+ for (int i = 0; i < nb; i++) {
+ const float * x_block = x + i * QK_K;
+
+ v128_t min_vec = wasm_v128_load(x_block);
+ v128_t max_vec = min_vec;
+
+ for (int j = 4; j < QK_K; j += 4) {
+ v128_t x_vec = wasm_v128_load(x_block + j);
+ max_vec = wasm_f32x4_pmax(max_vec, x_vec);
+ min_vec = wasm_f32x4_pmin(min_vec, x_vec);
+ }
+ max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
+ max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
+ min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
+ min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
+ float max = wasm_f32x4_extract_lane(max_vec, 0);
+ float min = wasm_f32x4_extract_lane(min_vec, 0);
+ float amax = -min > max ? min : max;
+
+ if (amax == 0.0f) {
+ yc[i].d = 0.0f;
+ const v128_t zero = wasm_i8x16_splat(0);
+ for (int j = 0; j < QK_K; j += 16) {
+ wasm_v128_store(yc[i].qs + j, zero);
+ }
+ continue;
+ }
+
+ const float iscale = -127.0f / amax;
+ const v128_t scale_vec = wasm_f32x4_splat(iscale);
+
+ // Process 16 elements per iteration
+ for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
+ // Load and quantize 16 floats
+ v128_t x0 = wasm_v128_load(x_block + j);
+ v128_t x1 = wasm_v128_load(x_block + j + 4);
+ v128_t x2 = wasm_v128_load(x_block + j + 8);
+ v128_t x3 = wasm_v128_load(x_block + j + 12);
+
+ v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
+ v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
+ v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
+ v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
+
+ // Convert to i32 with saturation
+ v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
+ v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
+ v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
+ v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
+
+ // Pack into 16 i8 values
+ v128_t i8 = wasm_i8x16_narrow_i16x8(
+ wasm_i16x8_narrow_i32x4(i0, i1),
+ wasm_i16x8_narrow_i32x4(i2, i3)
+ );
+ wasm_v128_store(yc[i].qs + j, i8);
+
+ // Calculate bsums using SIMD
+ v128_t sum16 = wasm_i16x8_add(
+ wasm_i16x8_extend_low_i8x16(i8),
+ wasm_i16x8_extend_high_i8x16(i8)
+ );
+ v128_t sum32 = wasm_i32x4_add(
+ wasm_i32x4_extend_low_i16x8(sum16),
+ wasm_i32x4_extend_high_i16x8(sum16)
+ );
+ sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
+ sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
+ yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
+ }
+
+ yc[i].d = 1.0f / iscale;
+ }
+ #else
  quantize_row_q8_K_ref(x, y, k);
+ #endif
  }
 
  //===================================== Dot products =================================
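The new `__wasm_simd128__` branch of `quantize_row_q8_K` added above follows the reference quantizer: find the signed value of largest magnitude, scale it to -127, round each element, record per-16-element sums in `bsums`, and store `d = 1/iscale`. A compact scalar sketch of that flow (reference-only; QK_K is assumed to be 256 and the struct below is an illustrative stand-in for `block_q8_K`):

```c
#include <math.h>
#include <stdint.h>
#include <string.h>

#define QK_K_REF 256  // block size assumed for this sketch

typedef struct {
    float   d;
    int8_t  qs[QK_K_REF];
    int16_t bsums[QK_K_REF / 16];
} block_q8_K_ref_t;  // illustrative stand-in for block_q8_K

// Reference-only sketch of the quantization flow in the new WASM branch,
// for one block of QK_K_REF floats.
static void quantize_block_q8_K_ref(const float *x, block_q8_K_ref_t *y) {
    float max = x[0], min = x[0];
    for (int j = 1; j < QK_K_REF; ++j) {
        if (x[j] > max) max = x[j];
        if (x[j] < min) min = x[j];
    }
    const float amax = -min > max ? min : max;   // signed value of largest magnitude
    if (amax == 0.0f) {
        y->d = 0.0f;
        memset(y->qs, 0, sizeof(y->qs));
        memset(y->bsums, 0, sizeof(y->bsums));
        return;
    }
    const float iscale = -127.0f / amax;         // amax is mapped to -127
    for (int j = 0; j < QK_K_REF; ++j) {
        y->qs[j] = (int8_t)nearbyintf(iscale * x[j]);
    }
    // per-16-element sums consumed by the K-quant dot products
    for (int jb = 0; jb < QK_K_REF / 16; ++jb) {
        int sum = 0;
        for (int j = 0; j < 16; ++j) {
            sum += y->qs[16*jb + j];
        }
        y->bsums[jb] = (int16_t)sum;
    }
    y->d = 1.0f / iscale;
}
```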
@@ -2011,6 +2112,94 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
2011
2112
  }
2012
2113
 
2013
2114
  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2115
+ #elif defined __wasm_simd128__
2116
+ v128_t sumv = wasm_f32x4_splat(0.0f);
2117
+
2118
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
2119
+ const v128_t s8b = wasm_i8x16_splat(0x8);
2120
+
2121
+ for (; ib + 1 < nb; ib += 2) {
2122
+ const block_q4_0 * restrict x0 = &x[ib];
2123
+ const block_q4_0 * restrict x1 = &x[ib + 1];
2124
+ const block_q8_0 * restrict y0 = &y[ib];
2125
+ const block_q8_0 * restrict y1 = &y[ib + 1];
2126
+
2127
+ // Load and process x0
2128
+ v128_t v0_0 = wasm_v128_load(x0->qs);
2129
+ v128_t v0_0l = wasm_v128_and(v0_0, m4b);
2130
+ v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
2131
+ v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
2132
+ v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
2133
+
2134
+ // Load y0 vectors
2135
+ v128_t y0_l = wasm_v128_load(y0->qs);
2136
+ v128_t y0_h = wasm_v128_load(y0->qs + 16);
2137
+
2138
+ // Extend to i16x8 and compute dot products
2139
+ v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
2140
+ v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
2141
+ v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
2142
+ v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
2143
+
2144
+ v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
2145
+ v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
2146
+ v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
2147
+ v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
2148
+
2149
+ v128_t dp0 = wasm_i32x4_add(
2150
+ wasm_i32x4_add(
2151
+ wasm_i32x4_dot_i16x8(dx0l, dy0ll),
2152
+ wasm_i32x4_dot_i16x8(dx0h, dy0lh)
2153
+ ),
2154
+ wasm_i32x4_add(
2155
+ wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
2156
+ wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
2157
+ )
2158
+ );
2159
+
2160
+ // Load and process x1
2161
+ v128_t v0_1 = wasm_v128_load(x1->qs);
2162
+ v128_t v0_1l = wasm_v128_and(v0_1, m4b);
2163
+ v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
2164
+ v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
2165
+ v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
2166
+
2167
+ // Load y1 vectors
2168
+ v128_t y1_l = wasm_v128_load(y1->qs);
2169
+ v128_t y1_h = wasm_v128_load(y1->qs + 16);
2170
+
2171
+ // Extend to i16x8 and compute dot products
2172
+ v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
2173
+ v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
2174
+ v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
2175
+ v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
2176
+
2177
+ v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
2178
+ v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
2179
+ v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
2180
+ v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
2181
+
2182
+ v128_t dp1 = wasm_i32x4_add(
2183
+ wasm_i32x4_add(
2184
+ wasm_i32x4_dot_i16x8(dx1l, dy1ll),
2185
+ wasm_i32x4_dot_i16x8(dx1h, dy1lh)
2186
+ ),
2187
+ wasm_i32x4_add(
2188
+ wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
2189
+ wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
2190
+ )
2191
+ );
2192
+
2193
+ // Accumulate results with scaling
2194
+ float scale0 = LM_GGML_FP16_TO_FP32(x0->d) * LM_GGML_FP16_TO_FP32(y0->d);
2195
+ float scale1 = LM_GGML_FP16_TO_FP32(x1->d) * LM_GGML_FP16_TO_FP32(y1->d);
2196
+
2197
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
2198
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
2199
+ }
2200
+
2201
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
2202
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
2014
2203
  #elif defined(__AVX2__)
2015
2204
  // Initialize accumulator with zeros
2016
2205
  __m256 acc = _mm256_setzero_ps();
@@ -2232,21 +2421,22 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
2232
2421
  }
2233
2422
 
2234
2423
  sumf = hsum_float_8(acc);
2424
+
2235
2425
  #elif defined(__loongarch_sx)
2236
2426
  // set constants
2237
2427
  const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
2238
2428
  const __m128i off = __lsx_vreplgr2vr_b(8);
2239
2429
 
2240
2430
  // Initialize accumulator with zeros
2241
- __m128 acc_0 = __lsx_vldi(0);
2242
- __m128 acc_1 = __lsx_vldi(0);
2243
- __m128 acc_2 = __lsx_vldi(0);
2244
- __m128 acc_3 = __lsx_vldi(0);
2431
+ __m128 acc_0 = (__m128)__lsx_vldi(0);
2432
+ __m128 acc_1 = (__m128)__lsx_vldi(0);
2433
+ __m128 acc_2 = (__m128)__lsx_vldi(0);
2434
+ __m128 acc_3 = (__m128)__lsx_vldi(0);
2245
2435
 
2246
2436
  for (; ib + 1 < nb; ib += 2) {
2247
2437
 
2248
2438
  // Compute combined scale for the block 0 and 1
2249
- const __m128 d_0_1 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) );
2439
+ const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) );
2250
2440
 
2251
2441
  const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
2252
2442
 
@@ -2264,7 +2454,7 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
2264
2454
  //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
2265
2455
 
2266
2456
  // Compute combined scale for the block 2 and 3
2267
- const __m128 d_2_3 = __lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) );
2457
+ const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( LM_GGML_FP16_TO_FP32(x[ib + 1].d) * LM_GGML_FP16_TO_FP32(y[ib + 1].d) );
2268
2458
 
2269
2459
  const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
2270
2460
 
@@ -2696,10 +2886,10 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void
2696
2886
  }
2697
2887
 
2698
2888
  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2699
- #elif defined(__wasm_simd128__)
2889
+ #elif defined __wasm_simd128__
2700
2890
  v128_t sumv = wasm_f32x4_splat(0.0f);
2701
2891
 
2702
- uint32_t qh;
2892
+ uint32_t qh_;
2703
2893
  uint64_t tmp[4];
2704
2894
 
2705
2895
  // TODO: check if unrolling this is better
@@ -2710,12 +2900,12 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void
2710
2900
  const v128_t m4b = wasm_i8x16_splat(0x0F);
2711
2901
 
2712
2902
  // extract the 5th bit
2713
- memcpy(&qh, x0->qh, sizeof(qh));
2903
+ memcpy(&qh_, x0->qh, sizeof(qh_));
2714
2904
 
2715
- tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
2716
- tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
2717
- tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
2718
- tmp[3] = table_b2b_1[(qh >> 24) ];
2905
+ tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF];
2906
+ tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF];
2907
+ tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
2908
+ tmp[3] = table_b2b_1[(qh_ >> 24) ];
2719
2909
 
2720
2910
  const v128_t qhl = wasm_v128_load(tmp + 0);
2721
2911
  const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3057,12 +3247,12 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void
3057
3247
  }
3058
3248
 
3059
3249
  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
3060
- #elif defined(__wasm_simd128__)
3250
+ #elif defined __wasm_simd128__
3061
3251
  v128_t sumv = wasm_f32x4_splat(0.0f);
3062
3252
 
3063
3253
  float summs = 0.0f;
3064
3254
 
3065
- uint32_t qh;
3255
+ uint32_t qh_;
3066
3256
  uint64_t tmp[4];
3067
3257
 
3068
3258
  // TODO: check if unrolling this is better
@@ -3075,12 +3265,12 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void
3075
3265
  const v128_t m4b = wasm_i8x16_splat(0x0F);
3076
3266
 
3077
3267
  // extract the 5th bit
3078
- memcpy(&qh, x0->qh, sizeof(qh));
3268
+ memcpy(&qh_, x0->qh, sizeof(qh_));
3079
3269
 
3080
- tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
3081
- tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
3082
- tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
3083
- tmp[3] = table_b2b_0[(qh >> 24) ];
3270
+ tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF];
3271
+ tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF];
3272
+ tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
3273
+ tmp[3] = table_b2b_0[(qh_ >> 24) ];
3084
3274
 
3085
3275
  const v128_t qhl = wasm_v128_load(tmp + 0);
3086
3276
  const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3573,6 +3763,45 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
3573
3763
  }
3574
3764
 
3575
3765
  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3766
+ #elif defined __wasm_simd128__
3767
+ v128_t sumv = wasm_f32x4_splat(0.0f);
3768
+
3769
+ for (; ib < nb; ++ib) {
3770
+ const block_q8_0 * restrict x0 = &x[ib];
3771
+ const block_q8_0 * restrict y0 = &y[ib];
3772
+
3773
+ const v128_t x0_0 = wasm_v128_load(x0->qs);
3774
+ const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
3775
+ const v128_t y0_0 = wasm_v128_load(y0->qs);
3776
+ const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
3777
+
3778
+ // Extend 8-bit to 16-bit
3779
+ const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
3780
+ const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
3781
+ const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
3782
+ const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
3783
+
3784
+ const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
3785
+ const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
3786
+ const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
3787
+ const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
3788
+
3789
+ // Compute dot products
3790
+ const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
3791
+ const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
3792
+ const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
3793
+ const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
3794
+
3795
+ // Sum all dot products
3796
+ const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
3797
+
3798
+ // Convert to float and accumulate
3799
+ const float scale = LM_GGML_FP16_TO_FP32(x0->d) * LM_GGML_FP16_TO_FP32(y0->d);
3800
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
3801
+ }
3802
+
3803
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3804
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
3576
3805
  #elif defined(__AVX2__)
3577
3806
  // Initialize accumulator with zeros
3578
3807
  __m256 acc = _mm256_setzero_ps();
@@ -4447,6 +4676,106 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
4447
4676
 
4448
4677
  *s = hsum_float_8(acc);
4449
4678
 
4679
+ #elif defined __wasm_simd128__
4680
+ float sumf = 0;
4681
+
4682
+ for (int i = 0; i < nb; ++i) {
4683
+ const uint8_t * q2 = x[i].qs;
4684
+ const int8_t * q8 = y[i].qs;
4685
+ const uint8_t * sc = x[i].scales;
4686
+
4687
+ // Vectorized summs calculation
4688
+ v128_t summs_vec = wasm_i32x4_splat(0);
4689
+ {
4690
+ v128_t sc_vec = wasm_v128_load(sc);
4691
+ v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
4692
+
4693
+ v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
4694
+ v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
4695
+
4696
+ v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
4697
+ v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
4698
+
4699
+ summs_vec = wasm_i32x4_add(
4700
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
4701
+ wasm_i32x4_dot_i16x8(sc_high, bsums2)),
4702
+ summs_vec
4703
+ );
4704
+
4705
+ summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
4706
+ summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
4707
+ }
4708
+ int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
4709
+
4710
+ // Vectorized isum calculation
4711
+ int32_t isum = 0;
4712
+ const uint8_t * sc_ptr = sc;
4713
+ const int k_iters = QK_K/128;
4714
+
4715
+ for (int k = 0; k < k_iters; ++k) {
4716
+ v128_t isum_vec = wasm_i32x4_splat(0);
4717
+ int shift = 0;
4718
+
4719
+ for (int j = 0; j < 4; ++j) {
4720
+ const int d0 = (sc_ptr[0] & 0xF);
4721
+ const int d1 = (sc_ptr[1] & 0xF);
4722
+ sc_ptr += 2;
4723
+
4724
+ // Process first 16 elements
4725
+ v128_t q2_0 = wasm_v128_load(q2);
4726
+ v128_t q8_0 = wasm_v128_load(q8);
4727
+ v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
4728
+ v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
4729
+
4730
+ // Process next 16 elements
4731
+ v128_t q2_1 = wasm_v128_load(q2 + 16);
4732
+ v128_t q8_1 = wasm_v128_load(q8 + 16);
4733
+ v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
4734
+ v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
4735
+
4736
+ // Calculate dot products
4737
+ v128_t p0 = wasm_i32x4_dot_i16x8(
4738
+ wasm_i16x8_extend_low_i8x16(q8_0),
4739
+ wasm_i16x8_extend_low_i8x16(q2_bits_0)
4740
+ );
4741
+ v128_t p1 = wasm_i32x4_dot_i16x8(
4742
+ wasm_i16x8_extend_high_i8x16(q8_0),
4743
+ wasm_i16x8_extend_high_i8x16(q2_bits_0)
4744
+ );
4745
+ v128_t p2 = wasm_i32x4_dot_i16x8(
4746
+ wasm_i16x8_extend_low_i8x16(q8_1),
4747
+ wasm_i16x8_extend_low_i8x16(q2_bits_1)
4748
+ );
4749
+ v128_t p3 = wasm_i32x4_dot_i16x8(
4750
+ wasm_i16x8_extend_high_i8x16(q8_1),
4751
+ wasm_i16x8_extend_high_i8x16(q2_bits_1)
4752
+ );
4753
+
4754
+ // Accumulate scaled results
4755
+ v128_t scaled = wasm_i32x4_add(
4756
+ wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
4757
+ wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
4758
+ );
4759
+
4760
+ isum_vec = wasm_i32x4_add(isum_vec, scaled);
4761
+ q8 += 32;
4762
+ shift += 2;
4763
+ }
4764
+ q2 += 32;
4765
+
4766
+ // Horizontal sum of isum_vec
4767
+ isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
4768
+ isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
4769
+ isum += wasm_i32x4_extract_lane(isum_vec, 0);
4770
+ }
4771
+
4772
+ const float dall = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
4773
+ const float dmin = LM_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
4774
+ sumf += dall * isum - dmin * summs;
4775
+ }
4776
+
4777
+ *s = sumf;
4778
+
4450
4779
  #elif defined __riscv_v_intrinsic
4451
4780
 
4452
4781
  float sumf = 0;
@@ -4666,9 +4995,6 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
4666
4995
 
4667
4996
  #elif defined __loongarch_asx
4668
4997
 
4669
- const __m256i m3 = __lasx_xvreplgr2vr_b(3);
4670
- const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
4671
-
4672
4998
  __m256 acc = (__m256)__lasx_xvldi(0);
4673
4999
 
4674
5000
  for (int i = 0; i < nb; ++i) {
@@ -4679,18 +5005,15 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
4679
5005
  const uint8_t * restrict q2 = x[i].qs;
4680
5006
  const int8_t * restrict q8 = y[i].qs;
4681
5007
 
4682
- const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
4683
- const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
4684
- const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
4685
- const __m256i mins = lasx_ext8_16(mins8);
5008
+ const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
5009
+ const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
5010
+ const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
4686
5011
  const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
4687
5012
 
4688
5013
  acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
4689
5014
 
4690
- const __m256i all_scales = lasx_ext8_16(scales8);
4691
- const __m128i l_scales = lasx_extracti128(all_scales, 0);
4692
- const __m128i h_scales = lasx_extracti128(all_scales, 1);
4693
- const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
5015
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
5016
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
4694
5017
 
4695
5018
  __m256i sumi = __lasx_xvldi(0);
4696
5019
 
@@ -4703,20 +5026,20 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
4703
5026
  const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
4704
5027
  const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
4705
5028
 
4706
- const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
4707
- const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
4708
- const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
4709
- const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
5029
+ const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
5030
+ const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
5031
+ const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
5032
+ const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);
4710
5033
 
4711
- __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
4712
- __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
4713
- __m256i p2 = lasx_maddubs_h(q2_2, q8_2);
4714
- __m256i p3 = lasx_maddubs_h(q2_3, q8_3);
5034
+ __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
5035
+ __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
5036
+ __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
5037
+ __m256i p3 = lasx_madd_h_b(q2_3, q8_3);
4715
5038
 
4716
- p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
4717
- p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
4718
- p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
4719
- p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
5039
+ p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
5040
+ p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
5041
+ p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
5042
+ p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);
4720
5043
 
4721
5044
  p0 = __lasx_xvadd_w(p0, p1);
4722
5045
  p2 = __lasx_xvadd_w(p2, p3);
@@ -5129,6 +5452,94 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5129
5452
 
5130
5453
  *s = hsum_float_8(acc);
5131
5454
 
5455
+ #elif defined __wasm_simd128__
5456
+ int8_t aux8[QK_K];
5457
+ float sums[8] = {0};
5458
+ uint32_t auxs[4];
5459
+
5460
+ float sumf = 0;
5461
+ for (int i = 0; i < nb; ++i) {
5462
+ const uint8_t * restrict q3 = x[i].qs;
5463
+ const uint8_t * restrict hm = x[i].hmask;
5464
+ const int8_t * restrict q8 = y[i].qs;
5465
+
5466
+ // Process blocks with SIMD
5467
+ int8_t * a = aux8;
5468
+ uint8_t m = 1;
5469
+ for (int j = 0; j < QK_K; j += 128) {
5470
+ for (int shift = 0; shift <= 6; shift += 2) {
5471
+ v128_t v_m = wasm_i8x16_splat(m);
5472
+ for (int l = 0; l < 32; l += 16) {
5473
+ v128_t v_q3 = wasm_v128_load(q3 + l);
5474
+ v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
5475
+ v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
5476
+
5477
+ v128_t v_hm = wasm_v128_load(hm + l);
5478
+ v128_t v_mask = wasm_v128_and(v_hm, v_m);
5479
+ v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
5480
+
5481
+ v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
5482
+ wasm_v128_store(a + l, v_low2);
5483
+ }
5484
+ a += 32;
5485
+ m <<= 1;
5486
+ }
5487
+ q3 += 32;
5488
+ }
5489
+
5490
+ // Extract scales
5491
+ memcpy(auxs, x[i].scales, 12);
5492
+ uint32_t tmp = auxs[2];
5493
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
5494
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
5495
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
5496
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
5497
+ const int8_t * scales = (const int8_t *)auxs;
5498
+
5499
+ // SIMD dot product with register accumulators
5500
+ v128_t v_acc0 = wasm_i32x4_splat(0);
5501
+ v128_t v_acc1 = wasm_i32x4_splat(0);
5502
+ a = aux8;
5503
+ for (int j = 0; j < QK_K/16; ++j) {
5504
+ const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
5505
+
5506
+ // Process 16 elements per iteration
5507
+ for (int k = 0; k < 2; ++k) {
5508
+ const v128_t v_q8 = wasm_i16x8_load8x8(q8);
5509
+ const v128_t v_a = wasm_i16x8_load8x8(a);
5510
+
5511
+ v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
5512
+ v_prod = wasm_i16x8_mul(v_prod, v_scale);
5513
+
5514
+ v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
5515
+ v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
5516
+
5517
+ q8 += 8;
5518
+ a += 8;
5519
+ }
5520
+ }
5521
+
5522
+ // Accumulate results
5523
+ const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
5524
+ const v128_t v_d = wasm_f32x4_splat(d);
5525
+ v128_t v_sum = wasm_f32x4_add(
5526
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
5527
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
5528
+ );
5529
+
5530
+ // Accumulate into sums vector
5531
+ wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
5532
+ }
5533
+
5534
+ // Horizontal sum
5535
+ v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
5536
+ sumf = wasm_f32x4_extract_lane(v_sum, 0) +
5537
+ wasm_f32x4_extract_lane(v_sum, 1) +
5538
+ wasm_f32x4_extract_lane(v_sum, 2) +
5539
+ wasm_f32x4_extract_lane(v_sum, 3);
5540
+
5541
+ *s = sumf;
5542
+
5132
5543
  #elif defined __riscv_v_intrinsic
5133
5544
 
5134
5545
  uint32_t aux[3];
@@ -5384,8 +5795,6 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5384
5795
 
5385
5796
  #elif defined __loongarch_asx
5386
5797
 
5387
- const __m256i m3 = __lasx_xvreplgr2vr_b(3);
5388
- const __m256i mone = __lasx_xvreplgr2vr_b(1);
5389
5798
  const __m128i m32 = __lsx_vreplgr2vr_b(32);
5390
5799
 
5391
5800
  __m256 acc = (__m256)__lasx_xvldi(0);
@@ -5405,10 +5814,9 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5405
5814
  (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
5406
5815
  (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
5407
5816
  scales128 = __lsx_vsub_b(scales128, m32);
5408
- const __m256i all_scales = lasx_ext8_16(scales128);
5409
- const __m128i l_scales = lasx_extracti128(all_scales, 0);
5410
- const __m128i h_scales = lasx_extracti128(all_scales, 1);
5411
- const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
5817
+
5818
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
5819
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
5412
5820
 
5413
5821
  // high bit
5414
5822
  const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
@@ -5416,35 +5824,23 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5416
5824
  // integer accumulator
5417
5825
  __m256i sumi = __lasx_xvldi(0);
5418
5826
 
5419
- int bit = 0;
5420
- int is = 0;
5421
- __m256i xvbit;
5422
-
5423
-
5424
5827
  for (int j = 0; j < QK_K/128; ++j) {
5425
5828
  // load low 2 bits
5426
5829
  const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
5427
5830
 
5428
- xvbit = __lasx_xvreplgr2vr_h(bit);
5429
5831
  // prepare low and high bits
5430
- const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
5431
- const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5432
- ++bit;
5433
-
5434
- xvbit = __lasx_xvreplgr2vr_h(bit);
5435
- const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
5436
- const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5437
- ++bit;
5438
-
5439
- xvbit = __lasx_xvreplgr2vr_h(bit);
5440
- const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
5441
- const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5442
- ++bit;
5443
-
5444
- xvbit = __lasx_xvreplgr2vr_h(bit);
5445
- const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
5446
- const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5447
- ++bit;
5832
+ const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
5833
+ const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
5834
+ const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
5835
+ const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
5836
+ const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
5837
+ const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
5838
+ const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
5839
+ const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
5840
+ const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
5841
+ const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
5842
+ const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
5843
+ const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
5448
5844
 
5449
5845
  // load Q8 quants
5450
5846
  const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
@@ -5452,29 +5848,16 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5452
5848
  const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
5453
5849
  const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
5454
5850
 
5455
- // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h,
5456
- // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
5457
- // and 2 if the high bit was set)
5458
- __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
5459
- __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
5460
- __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
5461
- __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
5462
-
5463
- __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
5464
- __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
5465
- __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
5466
- __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
5467
-
5468
- p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
5469
- p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
5470
- p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
5471
- p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
5851
+ __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
5852
+ __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
5853
+ __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
5854
+ __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
5472
5855
 
5473
5856
  // multiply with scales
5474
- p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
5475
- p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
5476
- p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
5477
- p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
5857
+ p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
5858
+ p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
5859
+ p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
5860
+ p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
5478
5861
 
5479
5862
  // accumulate
5480
5863
  p16_0 = __lasx_xvadd_w(p16_0, p16_1);
@@ -5482,7 +5865,7 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5482
5865
  sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
5483
5866
  }
5484
5867
  // multiply with block scale and accumulate
5485
- acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
5868
+ acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
5486
5869
  }
5487
5870
 
5488
5871
  *s = hsum_float_8(acc);
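The rewritten Q3_K LoongArch path folds the high-bit correction into the quant itself: where the high bit is clear, xvseqi_b produces an all-ones byte that, shifted left by two, ORs the value 0xFC (-4) into the 2-bit quant, and lasx_madd_h_b then appears to do a signed byte multiply with pairwise accumulation into 16-bit lanes, replacing the old maddubs-and-subtract pair. A scalar sketch of one 16-value sub-block under those assumptions (names are illustrative):

    #include <stdint.h>

    /* Sketch only: contribution of one Q3_K sub-block, i.e.
     * (low2 - 4 when the high bit is clear) * q8, weighted by the sub-block scale. */
    static int32_t q3k_subblock_dot_sketch(const uint8_t *q3_low2, /* 16 values already masked to 0..3 */
                                           const uint8_t *hbit,    /* 16 values, 0 or 1                */
                                           const int8_t  *q8,      /* 16 int8 activations              */
                                           int8_t         scale) { /* per-16-value scale               */
        int32_t sum = 0;
        for (int k = 0; k < 16; ++k) {
            const int8_t q = (int8_t)(q3_low2[k] | (hbit[k] ? 0x00 : 0xFC)); /* low2 or low2 - 4 */
            sum += (int32_t)q * q8[k];
        }
        return (int32_t)scale * sum;
    }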
@@ -5654,7 +6037,7 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
5654
6037
  }
5655
6038
  }
5656
6039
  *s = sumf;
5657
- #elif __ARM_NEON
6040
+ #elif defined __ARM_NEON
5658
6041
  const uint8x16_t m4b = vdupq_n_u8(0xf);
5659
6042
  const int32x4_t mzero = vdupq_n_s32(0);
5660
6043
 
@@ -5717,6 +6100,107 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
5717
6100
 
5718
6101
  *s = sumf;
5719
6102
 
6103
+ #elif defined __wasm_simd128__
6104
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
6105
+ float sumf = 0;
6106
+
6107
+ for (int i = 0; i < nb; ++i) {
6108
+ const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
6109
+ const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign
6110
+
6111
+ const uint8_t * restrict q4 = x[i].qs;
6112
+ const int8_t * restrict q8 = y[i].qs;
6113
+
6114
+ // Process scales and mins
6115
+ memcpy(utmp, x[i].scales, 12);
6116
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
6117
+ const uint32_t uaux = utmp[1] & kmask1;
6118
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
6119
+ utmp[2] = uaux;
6120
+ utmp[0] &= kmask1;
6121
+
6122
+ // Sum mins * q8sums
6123
+ int32_t sumi = 0;
6124
+ const int16_t * restrict q8sums = y[i].bsums;
6125
+ const uint8_t * m = (const uint8_t *)&utmp[2];
6126
+ for (int j = 0; j < 16; j += 2) {
6127
+ sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
6128
+ }
6129
+ sumf -= dmin * sumi;
6130
+
6131
+ int32_t sumi1 = 0;
6132
+ int32_t sumi2 = 0;
6133
+
6134
+ for (int j = 0; j < QK_K/64; ++j) {
6135
+ // Load 64 4-bit weights (32 bytes)
6136
+ const v128_t q4x0 = wasm_v128_load(q4);
6137
+ const v128_t q4x1 = wasm_v128_load(q4 + 16);
6138
+ q4 += 32;
6139
+
6140
+ // Split into low/high nibbles
6141
+ const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
6142
+ const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
6143
+ const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
6144
+ const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
6145
+
6146
+ // Load 64 8-bit values (64 bytes)
6147
+ const v128_t q8x0 = wasm_v128_load(q8);
6148
+ const v128_t q8x1 = wasm_v128_load(q8 + 16);
6149
+ const v128_t q8x2 = wasm_v128_load(q8 + 32);
6150
+ const v128_t q8x3 = wasm_v128_load(q8 + 48);
6151
+ q8 += 64;
6152
+
6153
+ // Low nibble products
6154
+ v128_t vacc1 = wasm_i32x4_dot_i16x8(
6155
+ wasm_i16x8_extend_low_i8x16(q4l0),
6156
+ wasm_i16x8_extend_low_i8x16(q8x0)
6157
+ );
6158
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
6159
+ wasm_i16x8_extend_high_i8x16(q4l0),
6160
+ wasm_i16x8_extend_high_i8x16(q8x0)
6161
+ ));
6162
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
6163
+ wasm_i16x8_extend_low_i8x16(q4l1),
6164
+ wasm_i16x8_extend_low_i8x16(q8x1)
6165
+ ));
6166
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
6167
+ wasm_i16x8_extend_high_i8x16(q4l1),
6168
+ wasm_i16x8_extend_high_i8x16(q8x1)
6169
+ ));
6170
+
6171
+ // High nibble products
6172
+ v128_t vacc2 = wasm_i32x4_dot_i16x8(
6173
+ wasm_i16x8_extend_low_i8x16(q4h0),
6174
+ wasm_i16x8_extend_low_i8x16(q8x2)
6175
+ );
6176
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
6177
+ wasm_i16x8_extend_high_i8x16(q4h0),
6178
+ wasm_i16x8_extend_high_i8x16(q8x2)
6179
+ ));
6180
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
6181
+ wasm_i16x8_extend_low_i8x16(q4h1),
6182
+ wasm_i16x8_extend_low_i8x16(q8x3)
6183
+ ));
6184
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
6185
+ wasm_i16x8_extend_high_i8x16(q4h1),
6186
+ wasm_i16x8_extend_high_i8x16(q8x3)
6187
+ ));
6188
+
6189
+ // Accumulate scaled results
6190
+ int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
6191
+ wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
6192
+ sumi1 += vacc1_sum * scales[2*j];
6193
+
6194
+ int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
6195
+ wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
6196
+ sumi2 += vacc2_sum * scales[2*j+1];
6197
+ }
6198
+
6199
+ sumf += d * (sumi1 + sumi2);
6200
+ }
6201
+
6202
+ *s = sumf;
6203
+
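For readers skimming the new __wasm_simd128__ Q4_K branch above: per superblock it subtracts dmin times the min-weighted Q8 block sums, then, for each group of 64 values, dots the low nibbles against the first 32 Q8 values and the high nibbles against the next 32, scaling the two partial sums by consecutive sub-block scales. A scalar sketch of that arithmetic, assuming QK_K = 256 and scales/mins already unpacked as in the code above (the function is illustrative, not library API):

    #include <stdint.h>

    static float q4k_superblock_sketch(const uint8_t *q4,     /* 128 bytes of packed nibbles    */
                                       const int8_t  *q8,     /* 256 int8 activations           */
                                       const uint8_t *scales, /* 8 sub-block scales             */
                                       const uint8_t *mins,   /* 8 sub-block mins               */
                                       const int16_t *bsums,  /* 16 partial sums of q8          */
                                       float d, float dmin) { /* both already multiplied by y.d */
        int32_t sumi = 0;
        for (int j = 0; j < 4; ++j) {                          /* 4 groups of 64 values */
            int32_t lo = 0, hi = 0;
            for (int k = 0; k < 32; ++k) {
                lo += (q4[32*j + k] & 0xF) * q8[64*j + k];       /* low nibbles  */
                hi += (q4[32*j + k] >> 4)  * q8[64*j + 32 + k];  /* high nibbles */
            }
            sumi += lo * scales[2*j] + hi * scales[2*j + 1];
        }
        int32_t sum_mins = 0;
        for (int j = 0; j < 8; ++j) {
            sum_mins += (bsums[2*j] + bsums[2*j + 1]) * mins[j];
        }
        return d * (float)sumi - dmin * (float)sum_mins;
    }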
5720
6204
  #elif defined __AVX2__
5721
6205
 
5722
6206
  const __m256i m4 = _mm256_set1_epi8(0xF);
@@ -6074,11 +6558,6 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
6074
6558
  *s = vec_extract(vsumf0, 0);
6075
6559
 
6076
6560
  #elif defined __loongarch_asx
6077
- LM_GGML_UNUSED(kmask1);
6078
- LM_GGML_UNUSED(kmask2);
6079
- LM_GGML_UNUSED(kmask3);
6080
-
6081
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
6082
6561
 
6083
6562
  __m256 acc = (__m256)__lasx_xvldi(0);
6084
6563
  __m128 acc_m = (__m128)__lsx_vldi(0);
@@ -6098,33 +6577,34 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
6098
6577
  const uint8_t * restrict q4 = x[i].qs;
6099
6578
  const int8_t * restrict q8 = y[i].qs;
6100
6579
 
6101
- const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
6580
+ const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
6581
+ const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
6582
+ const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
6102
6583
 
6103
6584
  const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
6104
6585
  const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
6105
- const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
6586
+ const __m128i prod = lsx_madd_h(mins128, q8s);
6106
6587
  acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
6107
6588
 
6108
- const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
6109
- const __m256i scales = lasx_insertf128(sc128, sc128);
6589
+ const __m256i scales = lasx_insertf128(scales128, scales128);
6110
6590
 
6111
6591
  __m256i sumi = __lasx_xvldi(0);
6112
6592
 
6113
6593
  for (int j = 0; j < QK_K/64; ++j) {
6114
6594
 
6115
- const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
6116
- const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
6595
+ const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
6596
+ const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
6117
6597
 
6118
6598
  const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
6119
- const __m256i q4l = __lasx_xvand_v(q4bits, m4);
6120
- const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
6599
+ const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
6600
+ const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
6121
6601
 
6122
6602
  const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
6123
- __m256i p16l = lasx_maddubs_h(q4l, q8l);
6603
+ __m256i p16l = lasx_madd_h_b(q4l, q8l);
6124
6604
  p16l = lasx_madd_h(scale_l, p16l);
6125
6605
 
6126
6606
  const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
6127
- __m256i p16h = lasx_maddubs_h(q4h, q8h);
6607
+ __m256i p16h = lasx_madd_h_b(q4h, q8h);
6128
6608
  p16h = lasx_madd_h(scale_h, p16h);
6129
6609
  const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
6130
6610
 
@@ -6141,9 +6621,7 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
6141
6621
  acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
6142
6622
 
6143
6623
 
6144
- ft_union fi;
6145
- fi.i = __lsx_vpickve2gr_w(acc_m, 0);
6146
- *s = hsum_float_8(acc) + fi.f ;
6624
+ *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
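One detail of the Q4_K LoongArch cleanup worth spelling out: the high-nibble extraction no longer needs the 0xF mask, because xvsrli_b shifts each byte independently and zero-fills, whereas the previous xvsrli_h shifted 16-bit lanes and dragged bits of the neighbouring byte into the result. A minimal scalar illustration (plain byte array assumed; the name is mine):

    #include <stdint.h>

    static void high_nibbles_sketch(const uint8_t *packed, uint8_t *out, int n) {
        for (int i = 0; i < n; ++i) {
            out[i] = packed[i] >> 4;  /* per-byte shift: upper four bits are already zero */
            /* A 16-bit view, ((uint16_t)packed[i+1] << 8 | packed[i]) >> 4, would mix the */
            /* next byte's low bits into this one, which is why the old code masked 0x0F.  */
        }
    }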
6147
6625
  #else
6148
6626
 
6149
6627
  const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -6469,6 +6947,118 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
6469
6947
 
6470
6948
  *s = hsum_float_8(acc) + summs;
6471
6949
 
6950
+ #elif defined __wasm_simd128__
6951
+ //const uint8_t * scales = (const uint8_t*)&utmp[0];
6952
+ float sumf = 0;
6953
+
6954
+ for (int i = 0; i < nb; ++i) {
6955
+ const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
6956
+ const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign
6957
+
6958
+ const uint8_t * restrict q5 = x[i].qs;
6959
+ const uint8_t * restrict qh = x[i].qh;
6960
+ const int8_t * restrict q8 = y[i].qs;
6961
+
6962
+ // Process scales and mins
6963
+ memcpy(utmp, x[i].scales, 12);
6964
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
6965
+ const uint32_t uaux = utmp[1] & kmask1;
6966
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
6967
+ utmp[2] = uaux;
6968
+ utmp[0] &= kmask1;
6969
+
6970
+ // Sum mins * q8sums
6971
+ int32_t sumi_mins = 0;
6972
+ const int16_t * restrict q8sums = y[i].bsums;
6973
+ const uint8_t * m = (const uint8_t *)&utmp[2];
6974
+ for (int j = 0; j < 16; j += 2) {
6975
+ sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
6976
+ }
6977
+ sumf -= dmin * sumi_mins; // Correct subtraction
6978
+
6979
+ v128_t qh0 = wasm_v128_load(qh);
6980
+ v128_t qh1 = wasm_v128_load(qh + 16);
6981
+ const uint8_t * sc = (const uint8_t *)utmp;
6982
+
6983
+ int32_t sumi = 0;
6984
+
6985
+ for (int j = 0; j < QK_K/64; ++j) {
6986
+ const int shift = j * 2;
6987
+ v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
6988
+ v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
6989
+
6990
+ v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
6991
+ v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
6992
+ v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
6993
+ v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
6994
+
6995
+ v128_t q5_0 = wasm_v128_load(q5);
6996
+ v128_t q5_1 = wasm_v128_load(q5 + 16);
6997
+ q5 += 32;
6998
+
6999
+ v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
7000
+ v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
7001
+ v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
7002
+ v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
7003
+
7004
+ v128_t q8_0 = wasm_v128_load(q8);
7005
+ v128_t q8_1 = wasm_v128_load(q8 + 16);
7006
+ v128_t q8_2 = wasm_v128_load(q8 + 32);
7007
+ v128_t q8_3 = wasm_v128_load(q8 + 48);
7008
+ q8 += 64;
7009
+
7010
+ // Process low quants
7011
+ v128_t pl0 = wasm_i32x4_dot_i16x8(
7012
+ wasm_i16x8_extend_low_i8x16(q5l_0),
7013
+ wasm_i16x8_extend_low_i8x16(q8_0)
7014
+ );
7015
+ pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
7016
+ wasm_i16x8_extend_high_i8x16(q5l_0),
7017
+ wasm_i16x8_extend_high_i8x16(q8_0)
7018
+ ));
7019
+ v128_t pl1 = wasm_i32x4_dot_i16x8(
7020
+ wasm_i16x8_extend_low_i8x16(q5l_1),
7021
+ wasm_i16x8_extend_low_i8x16(q8_1)
7022
+ );
7023
+ pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
7024
+ wasm_i16x8_extend_high_i8x16(q5l_1),
7025
+ wasm_i16x8_extend_high_i8x16(q8_1)
7026
+ ));
7027
+ v128_t sum_low = wasm_i32x4_add(pl0, pl1);
7028
+
7029
+ // Process high quants
7030
+ v128_t ph0 = wasm_i32x4_dot_i16x8(
7031
+ wasm_i16x8_extend_low_i8x16(q5h_0),
7032
+ wasm_i16x8_extend_low_i8x16(q8_2)
7033
+ );
7034
+ ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
7035
+ wasm_i16x8_extend_high_i8x16(q5h_0),
7036
+ wasm_i16x8_extend_high_i8x16(q8_2)
7037
+ ));
7038
+ v128_t ph1 = wasm_i32x4_dot_i16x8(
7039
+ wasm_i16x8_extend_low_i8x16(q5h_1),
7040
+ wasm_i16x8_extend_low_i8x16(q8_3)
7041
+ );
7042
+ ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
7043
+ wasm_i16x8_extend_high_i8x16(q5h_1),
7044
+ wasm_i16x8_extend_high_i8x16(q8_3)
7045
+ ));
7046
+ v128_t sum_high = wasm_i32x4_add(ph0, ph1);
7047
+
7048
+ // Accumulate with scale factors
7049
+ int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
7050
+ wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
7051
+ int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
7052
+ wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
7053
+
7054
+ sumi += sl * sc[2*j] + sh * sc[2*j+1];
7055
+ }
7056
+
7057
+ sumf += d * sumi;
7058
+ }
7059
+
7060
+ *s = sumf;
7061
+
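In the new __wasm_simd128__ Q5_K branch, the fifth bit of each quant comes from the shared qh bit-plane: bit 2*j selects it for the low nibbles of group j and bit 2*j+1 for the high nibbles, shifted up to bit position 4 before being OR-ed in. A scalar sketch of that reconstruction (illustrative names, not library API):

    #include <stdint.h>

    static uint8_t q5_value_sketch(uint8_t q5_byte, uint8_t qh_byte,
                                   int use_high_nibble /* 0 or 1 */, int group /* 0..3 */) {
        const uint8_t low4 = use_high_nibble ? (uint8_t)(q5_byte >> 4) : (uint8_t)(q5_byte & 0x0F);
        const int     bit  = 2 * group + use_high_nibble;            /* which qh bit plane   */
        const uint8_t hi   = (uint8_t)(((qh_byte >> bit) & 1) << 4); /* fifth bit, at bit 4  */
        return (uint8_t)(low4 | hi);                                 /* value in 0..31       */
    }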
6472
7062
  #elif defined __riscv_v_intrinsic
6473
7063
 
6474
7064
  const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -6691,19 +7281,11 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
6691
7281
  *s = vec_extract(vsumf0, 0);
6692
7282
 
6693
7283
  #elif defined __loongarch_asx
6694
- LM_GGML_UNUSED(kmask1);
6695
- LM_GGML_UNUSED(kmask2);
6696
- LM_GGML_UNUSED(kmask3);
6697
-
6698
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
6699
- const __m128i mzero = __lsx_vldi(0);
6700
- const __m256i mone = __lasx_xvreplgr2vr_b(1);
6701
7284
 
6702
7285
  __m256 acc = (__m256)__lasx_xvldi(0);
7286
+ __m128 acc_m = (__m128)__lsx_vldi(0);
6703
7287
 
6704
- float summs = 0.f;
6705
-
6706
- for (int i = 0; i < nb; ++i) {
7288
+ for (int i = 0; i < nb; ++i) {
6707
7289
 
6708
7290
  const uint8_t * restrict q5 = x[i].qs;
6709
7291
  const int8_t * restrict q8 = y[i].qs;
@@ -6718,49 +7300,40 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
6718
7300
  utmp[2] = uaux;
6719
7301
  utmp[0] &= kmask1;
6720
7302
 
6721
- const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
7303
+ const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
7304
+ const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
7305
+ const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
6722
7306
 
6723
7307
  const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
6724
7308
  const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
6725
- const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
6726
- const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
6727
- summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check
7309
+ const __m128i prod = lsx_madd_h(mins128, q8s);
7310
+ acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
6728
7311
 
6729
- const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
6730
- const __m256i scales = lasx_insertf128(sc128, sc128);
7312
+ const __m256i scales = lasx_insertf128(scales128, scales128);
6731
7313
 
6732
7314
  const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
6733
- __m256i hmask = mone;
6734
7315
 
6735
7316
  __m256i sumi = __lasx_xvldi(0);
6736
7317
 
6737
- int bit = 0;
6738
- __m256i xvbit;
6739
-
6740
7318
  for (int j = 0; j < QK_K/64; ++j) {
6741
7319
 
6742
- const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
6743
- const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
7320
+ const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
7321
+ const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
6744
7322
 
6745
7323
  const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
6746
7324
 
6747
- xvbit = __lasx_xvreplgr2vr_h(bit++);
6748
- const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
6749
- const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
6750
- const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
6751
- hmask = __lasx_xvslli_h(hmask, 1);
6752
-
6753
- xvbit = __lasx_xvreplgr2vr_h(bit++);
6754
- const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
6755
- const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
6756
- const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
6757
- hmask = __lasx_xvslli_h(hmask, 1);
7325
+ const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
7326
+ const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
7327
+ const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
7328
+ const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
7329
+ const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0);
7330
+ const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1);
6758
7331
 
6759
7332
  const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
6760
7333
  const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
6761
7334
 
6762
- __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
6763
- __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
7335
+ __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
7336
+ __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
6764
7337
 
6765
7338
  p16_0 = lasx_madd_h(scale_0, p16_0);
6766
7339
  p16_1 = lasx_madd_h(scale_1, p16_1);
@@ -6774,7 +7347,10 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
6774
7347
 
6775
7348
  }
6776
7349
 
6777
- *s = hsum_float_8(acc) + summs;
7350
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
7351
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
7352
+
7353
+ *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
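The Q5_K LoongArch rewrite replaces the per-iteration variable shift of hbits with a constant-generating trick: assuming lasx_xvandi_b_bit masks a single bit per byte, xvseqi_b(..., 0) yields 0xFF where that bit is clear, and xvnori_b(..., 0xEF) turns it into 0x10 exactly where the bit is set, ready to be OR-ed onto the low nibble. A scalar sketch of the same trick (names are mine):

    #include <stdint.h>

    /* Returns 0x10 if the selected bit of hbyte is set, 0x00 otherwise. */
    static uint8_t fifth_bit_sketch(uint8_t hbyte, int bit /* 0..7 */) {
        const uint8_t eq0 = (uint8_t)(((hbyte >> bit) & 1) == 0 ? 0xFF : 0x00); /* xvseqi_b(..., 0)    */
        return (uint8_t)~(eq0 | 0xEF);                                          /* xvnori_b(..., 0xEF) */
    }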
6778
7354
 
6779
7355
  #else
6780
7356
 
@@ -7132,6 +7708,85 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
7132
7708
 
7133
7709
  *s = hsum_float_8(acc);
7134
7710
 
7711
+ #elif defined __wasm_simd128__
7712
+ int8_t aux8[QK_K] __attribute__((aligned(16)));
7713
+ int32_t aux32[8] __attribute__((aligned(16))) = {0};
7714
+ float sums[8] __attribute__((aligned(16))) = {0};
7715
+
7716
+ for (int i = 0; i < nb; ++i) {
7717
+ // Unpack 6-bit quantized data into aux8 (unchanged)
7718
+ const uint8_t * restrict q4 = x[i].ql;
7719
+ const uint8_t * restrict qh = x[i].qh;
7720
+ int8_t * a = aux8;
7721
+ for (int j = 0; j < QK_K; j += 128) {
7722
+ for (int l = 0; l < 32; ++l) {
7723
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
7724
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
7725
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
7726
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
7727
+ }
7728
+ a += 128;
7729
+ q4 += 64;
7730
+ qh += 32;
7731
+ }
7732
+
7733
+ const int8_t * restrict a_ptr = aux8;
7734
+ const int8_t * restrict q8 = y[i].qs;
7735
+ v128_t acc0 = wasm_i32x4_splat(0);
7736
+ v128_t acc1 = wasm_i32x4_splat(0);
7737
+
7738
+ for (int j = 0; j < QK_K/16; ++j) {
7739
+ const int scale = x[i].scales[j];
7740
+ const v128_t vscale = wasm_i32x4_splat(scale);
7741
+
7742
+ // Load 16 elements from a and q8
7743
+ const v128_t a_vec = wasm_v128_load(a_ptr);
7744
+ const v128_t q8_vec = wasm_v128_load(q8);
7745
+
7746
+ // Process low 8 elements
7747
+ v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
7748
+ v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
7749
+ v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
7750
+ v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
7751
+ v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
7752
+
7753
+ // Process high 8 elements
7754
+ v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
7755
+ v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
7756
+ v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
7757
+ v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
7758
+ v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
7759
+
7760
+ // Scale and accumulate
7761
+ prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
7762
+ prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
7763
+ prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
7764
+ prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
7765
+
7766
+ acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
7767
+ acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
7768
+
7769
+ a_ptr += 16;
7770
+ q8 += 16;
7771
+ }
7772
+
7773
+ // Store accumulated results
7774
+ wasm_v128_store(&aux32[0], acc0);
7775
+ wasm_v128_store(&aux32[4], acc1);
7776
+
7777
+ const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
7778
+ for (int l = 0; l < 8; ++l) {
7779
+ sums[l] += d * aux32[l];
7780
+ }
7781
+ }
7782
+
7783
+ // Sum final results
7784
+ float sumf = 0;
7785
+ for (int l = 0; l < 8; ++l) {
7786
+ sumf += sums[l];
7787
+ }
7788
+ *s = sumf;
7789
+
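The scalar unpack at the top of this __wasm_simd128__ Q6_K branch already shows the 6-bit reconstruction; the SIMD loop that follows is, arithmetically, a per-16-value signed dot product weighted by the int8 sub-block scale and finally scaled by d. A condensed scalar sketch of that accumulation, assuming QK_K = 256 (the function is illustrative):

    #include <stdint.h>

    static float q6k_superblock_sketch(const int8_t *aux8,   /* 256 unpacked values, -32..31 */
                                       const int8_t *q8,     /* 256 int8 activations         */
                                       const int8_t *scales, /* 16 sub-block scales          */
                                       float d) {            /* d = x.d * y.d                */
        int32_t sumi = 0;
        for (int j = 0; j < 16; ++j) {
            int32_t s = 0;
            for (int k = 0; k < 16; ++k) {
                s += (int32_t)aux8[16*j + k] * q8[16*j + k];
            }
            sumi += s * scales[j];
        }
        return d * (float)sumi;
    }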
7135
7790
  #elif defined __riscv_v_intrinsic
7136
7791
 
7137
7792
  float sumf = 0;
@@ -7356,8 +8011,6 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
7356
8011
 
7357
8012
  #elif defined __loongarch_asx
7358
8013
 
7359
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
7360
- const __m256i m2 = __lasx_xvreplgr2vr_b(3);
7361
8014
  const __m256i m32s = __lasx_xvreplgr2vr_b(32);
7362
8015
 
7363
8016
  __m256 acc = (__m256)__lasx_xvldi(0);
@@ -7370,58 +8023,42 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
7370
8023
  const uint8_t * restrict qh = x[i].qh;
7371
8024
  const int8_t * restrict q8 = y[i].qs;
7372
8025
 
7373
- const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
8026
+ const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
8027
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
8028
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
7374
8029
 
7375
8030
  __m256i sumi = __lasx_xvldi(0);
7376
8031
 
7377
- int is = 0;
7378
-
7379
8032
  for (int j = 0; j < QK_K/128; ++j) {
7380
8033
 
7381
- const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
7382
- const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
7383
- const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
7384
- const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
7385
- is += 4;
7386
-
7387
8034
  const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
7388
8035
  const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
7389
8036
  const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
7390
8037
 
7391
- const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
7392
- const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
7393
- const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
7394
- const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
8038
+ const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
8039
+ const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
8040
+ const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
8041
+ const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
7395
8042
 
7396
- const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
7397
- const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
7398
- const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
7399
- const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
8043
+ const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
8044
+ const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
8045
+ const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
8046
+ const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
7400
8047
 
7401
8048
  const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
7402
8049
  const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
7403
8050
  const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
7404
8051
  const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
7405
8052
 
7406
- __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
7407
- __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
7408
- __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
7409
- __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
8053
+ __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
8054
+ __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
8055
+ __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
8056
+ __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
7410
8057
 
7411
- __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
7412
- __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
7413
- __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
7414
- __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
7415
-
7416
- p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
7417
- p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
7418
- p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
7419
- p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
7420
-
7421
- p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
7422
- p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
7423
- p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
7424
- p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
8058
+ p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
8059
+ p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
8060
+ p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
8061
+ p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
7425
8062
 
7426
8063
  sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
7427
8064
  sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
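The reworked Q6_K LoongArch hunk re-centres the quant before the multiply: the two qh bits are positioned at bits 4..5 above the low nibble, 32 is subtracted with xvsub_b, and the signed pairwise multiply-add then makes the old 32*q8 correction term unnecessary. A scalar sketch of one product under those assumptions (names are mine; the pairing of nibble halves with qh fields follows the layout used above):

    #include <stdint.h>

    static int32_t q6k_product_sketch(uint8_t ql_byte, uint8_t qh_byte,
                                      int high_nibble /* 0 or 1 */,
                                      int field       /* 0..3: which 2-bit group of qh */,
                                      int8_t q8) {
        const uint8_t low4 = high_nibble ? (uint8_t)(ql_byte >> 4) : (uint8_t)(ql_byte & 0x0F);
        const uint8_t hi2  = (uint8_t)(((qh_byte >> (2 * field)) & 0x3) << 4);
        const int8_t  q6   = (int8_t)((low4 | hi2) - 32);   /* value in -32..31 */
        return (int32_t)q6 * (int32_t)q8;
    }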
@@ -9746,13 +10383,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
9746
10383
  }
9747
10384
  #elif defined(__loongarch_asx)
9748
10385
  static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
9749
- const __m256i ax = __lasx_xvsigncov_b(x, x);
9750
- const __m256i sy = __lasx_xvsigncov_b(x, y);
9751
- __m256i tmp1, tmp2, tmp3;
9752
- tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
9753
- tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
9754
- tmp3 = __lasx_xvadd_h(tmp1, tmp2);
9755
- return __lasx_xvsat_h(tmp3, 15);
10386
+ const __m256i a = __lasx_xvmulwev_h_b(x, y);
10387
+ const __m256i b = __lasx_xvmulwod_h_b(x, y);
10388
+ return __lasx_xvadd_h(a, b);
9756
10389
  }
9757
10390
  #endif
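The simplified LoongArch mul_add_epi8 above now computes, per 16-bit lane, the sum of the two widened products of adjacent signed bytes, using the even/odd widening multiplies directly and dropping both the sign-transfer dance and the xvsat_h saturation of the previous version. A scalar sketch of the lane-wise result (illustrative only):

    #include <stdint.h>

    static void mul_add_epi8_sketch(const int8_t *x, const int8_t *y,
                                    int16_t *out, int lanes /* number of 16-bit lanes */) {
        for (int i = 0; i < lanes; ++i) {
            /* x[2i]*y[2i] + x[2i+1]*y[2i+1], truncated to the 16-bit lane width */
            out[i] = (int16_t)((int32_t)x[2*i] * y[2*i] + (int32_t)x[2*i + 1] * y[2*i + 1]);
        }
    }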
9758
10391
 
@@ -10802,67 +11435,31 @@ void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const voi
10802
11435
  #elif defined(__loongarch_asx)
10803
11436
 
10804
11437
  const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
10805
- const __m128i m4b = __lsx_vreplgr2vr_b(0x0f);
10806
11438
 
10807
11439
  __m256 accum = (__m256)__lasx_xvldi(0);
10808
- __m256i tmp1;
10809
- __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
10810
11440
 
10811
- mask_8f = __lsx_vreplgr2vr_b(0x8f);
10812
11441
  for (int ibl = 0; ibl < nb; ++ibl) {
10813
11442
  const uint8_t * qs = x[ibl].qs;
10814
11443
  const int8_t * q8 = y[ibl].qs;
10815
11444
  uint16_t sh = x[ibl].scales_h;
10816
11445
  __m256i sumi1 = __lasx_xvldi(0);
10817
11446
  __m256i sumi2 = __lasx_xvldi(0);
10818
- __m128i zero = __lsx_vldi(0);
10819
11447
  for (int ib = 0; ib < QK_K/32; ib += 2) {
10820
- const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
10821
- const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
11448
+ const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
11449
+ const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
10822
11450
  const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
10823
11451
  const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
10824
- tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f);
10825
- tmp0 = __lsx_vori_b(tmp2, 0x10);
10826
- mask = __lsx_vsle_b(zero, tmp2);
10827
- tmp3 = __lsx_vand_v(tmp0, mask);
10828
- tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
10829
-
10830
- tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
10831
- tmp0 = __lsx_vori_b(tmp2, 0x10);
10832
- mask = __lsx_vsle_b(zero, tmp2);
10833
- tmp4 = __lsx_vand_v(tmp0, mask);
10834
- tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
10835
-
10836
- const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
10837
-
10838
- tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
10839
- tmp0 = __lsx_vori_b(tmp2, 0x10);
10840
- mask = __lsx_vsle_b(zero, tmp2);
10841
- tmp3 = __lsx_vand_v(tmp0, mask);
10842
- tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
10843
-
10844
- tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
10845
- tmp0 = __lsx_vori_b(tmp2, 0x10);
10846
- mask = __lsx_vsle_b(zero, tmp2);
10847
- tmp4 = __lsx_vand_v(tmp0, mask);
10848
- tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
10849
-
10850
- const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
10851
-
11452
+ const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
11453
+ __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
11454
+ const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
11455
+ __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
10852
11456
  const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
10853
11457
  const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
10854
11458
  const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
10855
11459
  const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
10856
11460
  sh >>= 4;
10857
- __m256i tmp5, tmp6;
10858
- tmp1 = __lasx_xvreplgr2vr_h(ls1);
10859
- tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
10860
- tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
10861
- const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
10862
- tmp1 = __lasx_xvreplgr2vr_h(ls2);
10863
- tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
10864
- tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
10865
- const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
11461
+ const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
11462
+ const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
10866
11463
  sumi1 = __lasx_xvadd_w(p_1, sumi1);
10867
11464
  sumi2 = __lasx_xvadd_w(p_2, sumi2);
10868
11465
  }
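The rewritten iq4_xs path exploits the fact that every 4-bit index is already in 0..15, so a single vshuf_b against the 16-entry kvalues_iq4nl table performs the whole non-linear lookup and the old mask/compare/ori scaffolding can go. A scalar sketch of the lookup for one 16-byte load (illustrative names; the codebook pointer stands in for kvalues_iq4nl):

    #include <stdint.h>

    static void iq4_lookup_sketch(const uint8_t *qs,        /* 16 packed bytes = 32 indices */
                                  const int8_t values[16],  /* 16-entry codebook            */
                                  int8_t out_lo[16], int8_t out_hi[16]) {
        for (int i = 0; i < 16; ++i) {
            out_lo[i] = values[qs[i] & 0x0F];  /* low-nibble indices  */
            out_hi[i] = values[qs[i] >> 4];    /* high-nibble indices */
        }
    }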