cui-llama.rn 1.4.6 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. package/android/src/main/CMakeLists.txt +9 -2
  2. package/android/src/main/jni.cpp +52 -34
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/binary-ops.cpp +158 -0
  12. package/cpp/binary-ops.h +16 -0
  13. package/cpp/chat.cpp +1769 -1779
  14. package/cpp/chat.h +9 -1
  15. package/cpp/common.cpp +20 -522
  16. package/cpp/common.h +13 -36
  17. package/cpp/cpu-common.h +72 -0
  18. package/cpp/ggml-common.h +12 -6
  19. package/cpp/ggml-cpu-aarch64.cpp +1557 -80
  20. package/cpp/ggml-cpu-impl.h +2 -21
  21. package/cpp/ggml-cpu-quants.c +904 -405
  22. package/cpp/ggml-cpu.c +909 -13237
  23. package/cpp/ggml-impl.h +50 -23
  24. package/cpp/ggml-metal-impl.h +77 -3
  25. package/cpp/ggml-metal.m +794 -580
  26. package/cpp/ggml.c +92 -3
  27. package/cpp/ggml.h +29 -5
  28. package/cpp/gguf.cpp +1 -0
  29. package/cpp/llama-adapter.cpp +55 -20
  30. package/cpp/llama-adapter.h +11 -9
  31. package/cpp/llama-arch.cpp +217 -16
  32. package/cpp/llama-arch.h +25 -0
  33. package/cpp/llama-batch.h +2 -2
  34. package/cpp/llama-chat.cpp +54 -2
  35. package/cpp/llama-chat.h +3 -0
  36. package/cpp/llama-context.cpp +2294 -1238
  37. package/cpp/llama-context.h +214 -77
  38. package/cpp/llama-cparams.h +1 -0
  39. package/cpp/llama-graph.cpp +1695 -0
  40. package/cpp/llama-graph.h +592 -0
  41. package/cpp/llama-hparams.cpp +8 -0
  42. package/cpp/llama-hparams.h +17 -0
  43. package/cpp/llama-io.cpp +15 -0
  44. package/cpp/llama-io.h +35 -0
  45. package/cpp/llama-kv-cache.cpp +965 -303
  46. package/cpp/llama-kv-cache.h +145 -151
  47. package/cpp/llama-memory.cpp +1 -0
  48. package/cpp/llama-memory.h +21 -0
  49. package/cpp/llama-mmap.cpp +1 -1
  50. package/cpp/llama-model-loader.cpp +10 -5
  51. package/cpp/llama-model-loader.h +5 -3
  52. package/cpp/llama-model.cpp +9194 -201
  53. package/cpp/llama-model.h +40 -1
  54. package/cpp/llama-sampling.cpp +5 -0
  55. package/cpp/llama-vocab.cpp +36 -5
  56. package/cpp/llama.cpp +51 -9984
  57. package/cpp/llama.h +102 -22
  58. package/cpp/log.cpp +34 -0
  59. package/cpp/minja/chat-template.hpp +15 -7
  60. package/cpp/minja/minja.hpp +120 -94
  61. package/cpp/ops.cpp +8723 -0
  62. package/cpp/ops.h +128 -0
  63. package/cpp/rn-llama.cpp +44 -53
  64. package/cpp/rn-llama.h +2 -12
  65. package/cpp/sampling.cpp +3 -0
  66. package/cpp/sgemm.cpp +533 -88
  67. package/cpp/simd-mappings.h +888 -0
  68. package/cpp/speculative.cpp +4 -4
  69. package/cpp/unary-ops.cpp +186 -0
  70. package/cpp/unary-ops.h +28 -0
  71. package/cpp/vec.cpp +258 -0
  72. package/cpp/vec.h +802 -0
  73. package/ios/CMakeLists.txt +5 -2
  74. package/ios/RNLlama.mm +2 -2
  75. package/ios/RNLlamaContext.mm +40 -24
  76. package/package.json +1 -1
  77. package/src/NativeRNLlama.ts +6 -4
  78. package/src/index.ts +3 -1
  79. package/cpp/chat-template.hpp +0 -529
  80. package/cpp/minja.hpp +0 -2915
@@ -891,15 +891,15 @@ void quantize_row_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT
  }
  #elif defined(__riscv_v_intrinsic)

- size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+ size_t vl = QK8_0;

  for (int i = 0; i < nb; i++) {
  // load elements
- vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+ vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);

- vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+ vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
  vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
- vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
  float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);

  const float d = amax / ((1 << 7) - 1);
@@ -907,14 +907,14 @@ void quantize_row_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT

  y[i].d = LM_GGML_FP32_TO_FP16(d);

- vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+ vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);

  // convert to integer
- vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
- vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+ vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+ vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);

  // store result
- __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+ __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
  }

  #elif defined(__POWER9_VECTOR__)
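For reference, the vector loop above computes the standard q8_0 block quantization; a scalar sketch of the same per-block work, reusing x, y, nb and the FP16 conversion macro from the surrounding function (fabsf/roundf come from <math.h>, and the vector convert's tie-breaking may differ from roundf):

    // scalar sketch of the q8_0 block loop vectorized above
    for (int i = 0; i < nb; i++) {
        float amax = 0.0f;                        // absolute max of the block
        for (int j = 0; j < QK8_0; j++) {
            const float v = fabsf(x[i*QK8_0 + j]);
            if (v > amax) amax = v;
        }
        const float d  = amax / ((1 << 7) - 1);   // amax maps to 127
        const float id = d ? 1.0f / d : 0.0f;
        y[i].d = LM_GGML_FP32_TO_FP16(d);
        for (int j = 0; j < QK8_0; j++) {
            y[i].qs[j] = (int8_t) roundf(x[i*QK8_0 + j] * id);
        }
    }

The LMUL changes (f32m4 to f32m8, with the narrowing chain widened to i16m4/i8m2) let the fixed vl = QK8_0 = 32 fit one register group even at VLEN = 128, since VLEN x LMUL / SEW = 128 x 8 / 32 = 32.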
@@ -1229,15 +1229,15 @@ void quantize_row_q8_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT
  }
  #elif defined(__riscv_v_intrinsic)

- size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+ size_t vl = QK8_1;

  for (int i = 0; i < nb; i++) {
  // load elements
- vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+ vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);

- vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+ vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
  vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
- vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
  float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);

  const float d = amax / ((1 << 7) - 1);
@@ -1245,18 +1245,18 @@ void quantize_row_q8_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT

  y[i].d = LM_GGML_FP32_TO_FP16(d);

- vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+ vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);

  // convert to integer
- vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
- vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+ vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+ vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);

  // store result
- __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+ __riscv_vse8_v_i8m2(y[i].qs , vs, vl);

  // compute sum for y[i].s
  vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
- vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);

  // set y[i].s
  int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
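q8_1 differs from q8_0 only in the step shown last above: the quantized bytes are additionally reduced to one per-block sum (the widening vwredsum), which is what the block's s field is derived from. A scalar sketch of that extra step, reusing the surrounding loop's variables:

    // scalar sketch of the extra q8_1 step: sum of the 32 quantized bytes
    int sum = 0;
    for (int j = 0; j < QK8_1; j++) {
        sum += y[i].qs[j];          // int8 values accumulated as int
    }
    // the scalar reference then stores d * sum in y[i].s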
@@ -2391,33 +2391,31 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, con

  sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
  #elif defined(__riscv_v_intrinsic)
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
+ size_t vl = qk / 2;

  for (; ib < nb; ++ib) {
  // load elements
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+ vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);

- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+ vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+ vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);

  // mask and store lower part of x, and then upper part
- vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+ vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+ vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);

- vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+ vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+ vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);

  // subtract offset
- vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
- vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
+ vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
+ vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);

- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+ vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+ vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);

  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);

  int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
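A scalar sketch of one q4_0 x q8_0 block, matching the vector code above (each byte packs two 4-bit values stored with a +8 bias; the low nibbles pair with the first half of the q8 block and the high nibbles with the second half):

    // scalar sketch of one block of the q4_0 x q8_0 dot product
    int sumi = 0;
    for (int j = 0; j < qk/2; ++j) {
        const int v0 = (x[ib].qs[j] & 0x0F) - 8;   // low nibble, bias removed
        const int v1 = (x[ib].qs[j] >>   4) - 8;   // high nibble, bias removed
        sumi += v0 * y[ib].qs[j] + v1 * y[ib].qs[j + qk/2];
    }
    sumf += LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) * sumi;

The rewritten kernel also folds the second widening multiply into a vwmacc and keeps a single reduction instead of two chained vwredsum calls.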

@@ -2783,29 +2781,27 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, con

  sumf = hsum_float_8(acc) + summs;
  #elif defined(__riscv_v_intrinsic)
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
+ size_t vl = qk / 2;

  for (; ib < nb; ++ib) {
  // load elements
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+ vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);

- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+ vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+ vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);

  // mask and store lower part of x, and then upper part
- vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+ vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+ vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);

- vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+ vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+ vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);

- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+ vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+ vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);

  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);

  int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

@@ -3132,65 +3128,33 @@ void lm_ggml_vec_dot_q5_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, con

  sumf = hsum_float_8(acc);
  #elif defined(__riscv_v_intrinsic)
- uint32_t qh;
-
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
- // These temporary registers are for masking and shift operations
- vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
- vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
-
- vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
- vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+ size_t vl;
+ size_t vlenb = __riscv_vlenb();

  for (; ib < nb; ++ib) {
- memcpy(&qh, x[ib].qh, sizeof(uint32_t));
-
- // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
- vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
- vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
- vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
-
- // ((qh & (1u << (j + 16))) >> (j + 12));
- vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
- vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
-
- // narrowing
- vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
- vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
-
- vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
- vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
-
- // load
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
-
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
-
- vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
- vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
- vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
- vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
- vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
- vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
-
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-
- sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+ vl = qk / 2;
+ vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+ vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+ vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+ vint8m2_t v0c;
+ if (vlenb == 16) {
+ v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+ } else {
+ v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+ v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+ }
+
+ vl = qk;
+ vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+ qh = __riscv_vmnand_mm_b4(qh, qh, vl);
+ vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+ vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+ vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+ vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+ int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
+
+ sumf += (LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d)) * sumi;
  }

  #elif defined(__POWER9_VECTOR__)
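The new q5_0 path loads qh as a mask register (vlm), inverts it, and uses a masked subtract so that elements whose fifth bit is clear get 16 subtracted while elements whose bit is set keep their nibble value; that is the same reconstruction the removed shift/narrow sequence performed. A scalar sketch of one block, reusing the surrounding loop's variables (memcpy from <string.h>):

    // scalar sketch of one q5_0 x q8_0 block: value = nibble + 16*bit - 16
    uint32_t qh_bits;
    memcpy(&qh_bits, x[ib].qh, sizeof(qh_bits));
    int sumi = 0;
    for (int j = 0; j < qk/2; ++j) {
        const int bit0 = (qh_bits >> (j +  0)) & 1;   // 5th bit of element j
        const int bit1 = (qh_bits >> (j + 16)) & 1;   // 5th bit of element j + qk/2
        const int v0 = ((x[ib].qs[j] & 0x0F) | (bit0 << 4)) - 16;
        const int v1 = ((x[ib].qs[j] >>   4) | (bit1 << 4)) - 16;
        sumi += v0 * y[ib].qs[j] + v1 * y[ib].qs[j + qk/2];
    }
    sumf += LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d) * sumi;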
@@ -3503,60 +3467,30 @@ void lm_ggml_vec_dot_q5_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, con

  sumf = hsum_float_8(acc) + summs;
  #elif defined(__riscv_v_intrinsic)
- uint32_t qh;
-
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
- // temporary registers for shift operations
- vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
- vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+ size_t vl;
+ size_t vlenb = __riscv_vlenb();

  for (; ib < nb; ++ib) {
- memcpy(&qh, x[ib].qh, sizeof(uint32_t));
-
- // load qh
- vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
-
- // ((qh >> (j + 0)) << 4) & 0x10;
- vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
- vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
- vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
-
- // ((qh >> (j + 12)) ) & 0x10;
- vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
- vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
-
- // narrowing
- vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
- vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
-
- vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
- vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
-
- // load
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
-
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
-
- vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
- vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
- vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
- vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+ vl = qk / 2;
+ vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+ vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+ vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+ vint8m2_t v0c;
+ if (vlenb == 16) {
+ v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+ } else {
+ v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+ v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+ }
+
+ vl = qk;
+ vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+ vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+ vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+ vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+ vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+ int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);

  sumf += (LM_GGML_FP16_TO_FP32(x[ib].d)*LM_GGML_FP16_TO_FP32(y[ib].d))*sumi + LM_GGML_FP16_TO_FP32(x[ib].m)*LM_GGML_FP16_TO_FP32(y[ib].s);
  }
@@ -3970,17 +3904,17 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, con

  sumf = hsum_float_8(accum);
  #elif defined(__riscv_v_intrinsic)
- size_t vl = __riscv_vsetvl_e8m1(qk);
+ size_t vl = qk;

  for (; ib < nb; ++ib) {
  // load elements
- vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl);
- vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+ vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);
+ vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);

- vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
+ vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);

  vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
- vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+ vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);

  int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
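The q8_0 x q8_0 kernel is the simplest of the group: the widening multiply and reduction above compute a plain int8 dot product per block, scaled by both block scales. A scalar sketch using the surrounding loop's variables:

    // scalar sketch of one q8_0 x q8_0 block
    int sumi = 0;
    for (int j = 0; j < qk; j++) {
        sumi += x[ib].qs[j] * y[ib].qs[j];
    }
    sumf += sumi * LM_GGML_FP16_TO_FP32(x[ib].d) * LM_GGML_FP16_TO_FP32(y[ib].d);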

@@ -5174,84 +5108,182 @@ void lm_ggml_vec_dot_q2_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con

  #elif defined __riscv_v_intrinsic

+ const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;
- uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
-
- for (int i = 0; i < nb; ++i) {

- const uint8_t * q2 = x[i].qs;
- const int8_t * q8 = y[i].qs;
- const uint8_t * sc = x[i].scales;
+ uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
+ uint8_t atmp[16];

- const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);
+ switch (vector_length) {
+ case 256:
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * q2 = x[i].qs;
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * sc = x[i].scales;

- size_t vl = 16;
+ const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);

- vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
- vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+ size_t vl = 16;

- vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
+ vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+ vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);

- vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
- vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
- vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
- vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
- vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+ vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);

- sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
+ vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+ vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+ vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+ vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);

- vl = 32;
+ sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);

- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
- vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+ vl = 32;

- uint8_t is=0;
- int isum=0;
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);

- for (int j = 0; j < QK_K/128; ++j) {
- // load Q2
- vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+ uint8_t is = 0;
+ int isum = 0;

- vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
- vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl);
- vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl);
- vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
+ for (int j = 0; j < QK_K / 128; ++j) {
+ // load Q2
+ vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);

- // duplicate scale elements for product
- vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl);
- vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl);
- vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl);
- vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+ vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+ vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
+ vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
+ vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);

- vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
- vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
- vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
- vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
+ // duplicate scale elements for product
+ vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
+ vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
+ vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
+ vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);

- // load Q8
- vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
- vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
- vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl);
- vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+ vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+ vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+ vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+ vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));

- vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
- vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
- vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
- vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
+ // load Q8
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
+ vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
+ vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);

- vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
- vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
+ vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+ vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+ vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+ vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);

- isum += __riscv_vmv_x_s_i32m1_i32(isum1);
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);

- q2+=32; q8+=128; is=8;
+ isum += __riscv_vmv_x_s_i32m1_i32(isum1);

- }
+ q2 += 32;
+ q8 += 128;
+ is = 8;
+ }

- sumf += dall * isum;
+ sumf += dall * isum;
+ }
+ break;
+ case 128:
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * q2 = x[i].qs;
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * sc = x[i].scales;
+ const float dall = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);
+ uint8_t *patmp = atmp;
+ int vsums;
+ int tmp;
+ __asm__ __volatile__(
+ "vsetivli zero, 16, e8, m1\n\t"
+ "vmv.v.x v8, zero\n\t"
+ "vle8.v v1, (%[sc])\n\t"
+ "vand.vi v0, v1, 0xF\n\t"
+ "vsrl.vi v1, v1, 4\n\t"
+ "vse8.v v0, (%[scale])\n\t"
+ "vsetivli zero, 16, e16, m2\n\t"
+ "vle16.v v2, (%[bsums])\n\t"
+ "vzext.vf2 v0, v1\n\t"
+ "vwmul.vv v4, v0, v2\n\t"
+ "vsetivli zero, 16, e32, m4\n\t"
+ "vredsum.vs v8, v4, v8\n\t"
+ "vmv.x.s %[vsums], v8"
+ : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
+ : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ sumf += dmin * vsums;
+ int isum = 0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ __asm__ __volatile__(
+ "vsetvli zero, %[vl32], e8, m2\n\t"
+ "vle8.v v0, (%[q2])\n\t"
+ "vsrl.vi v2, v0, 2\n\t"
+ "vsrl.vi v4, v0, 4\n\t"
+ "vsrl.vi v6, v0, 6\n\t"
+ "vand.vi v0, v0, 0x3\n\t"
+ "vand.vi v2, v2, 0x3\n\t"
+ "vand.vi v4, v4, 0x3\n\t"
+ "vsetvli zero, %[vl128], e8, m8\n\t"
+ "vle8.v v8, (%[q8])\n\t"
+ "vsetvli zero, %[vl64], e8, m4\n\t"
+ "vwmul.vv v16, v0, v8\n\t"
+ "vwmul.vv v24, v4, v12\n\t"
+ "vsetivli zero, 16, e16, m2\n\t"
+ "vmv.v.x v0, zero\n\t"
+ "vwredsum.vs v10, v16, v0\n\t"
+ "vwredsum.vs v9, v18, v0\n\t"
+ "vwredsum.vs v8, v20, v0\n\t"
+ "vwredsum.vs v7, v22, v0\n\t"
+ "vwredsum.vs v11, v24, v0\n\t"
+ "vwredsum.vs v12, v26, v0\n\t"
+ "vwredsum.vs v13, v28, v0\n\t"
+ "vwredsum.vs v14, v30, v0\n\t"
+ "vsetivli zero, 4, e32, m1\n\t"
+ "vslideup.vi v10, v9, 1\n\t"
+ "vslideup.vi v8, v7, 1\n\t"
+ "vslideup.vi v11, v12, 1\n\t"
+ "vslideup.vi v13, v14, 1\n\t"
+ "vslideup.vi v10, v8, 2\n\t"
+ "vslideup.vi v11, v13, 2\n\t"
+ "vsetivli zero, 8, e32, m2\n\t"
+ "vle8.v v15, (%[scale])\n\t"
+ "vzext.vf4 v12, v15\n\t"
+ "vmul.vv v10, v10, v12\n\t"
+ "vredsum.vs v0, v10, v0\n\t"
+ "vmv.x.s %[tmp], v0\n\t"
+ "add %[isum], %[isum], %[tmp]"
+ : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+ : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ q2 += 32; q8 += 128; patmp += 8;
+ }

+ sumf += dall * isum;
+ }
+ break;
+ default:
+ assert(false && "Unsupported vector length");
+ break;
  }

  *s = sumf;
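The K-quant kernels from here on pick their implementation at run time instead of assuming a fixed register width: __riscv_vlenb() reports VLEN in bytes, the 256-bit case keeps the original intrinsic code, and the 128-bit case uses the new hand-scheduled inline assembly. A minimal standalone sketch of that dispatch pattern (the function name is illustrative, not part of the package):

    #include <assert.h>
    #include <riscv_vector.h>

    // illustrative: choose a kernel variant from the runtime vector length
    static int select_rvv_kernel_width(void) {
        const int vector_length = __riscv_vlenb() * 8;   // VLEN in bits
        switch (vector_length) {
            case 256: return 256;   // intrinsic path
            case 128: return 128;   // inline-assembly path
            default:  assert(false && "Unsupported vector length"); return 0;
        }
    }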
@@ -6116,97 +6148,221 @@ void lm_ggml_vec_dot_q3_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
  uint32_t aux[3];
  uint32_t utmp[4];

+ const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;
- for (int i = 0; i < nb; ++i) {

- const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
- const uint8_t * LM_GGML_RESTRICT qh = x[i].hmask;
- const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
+ switch (vector_length) {
+ case 256:
+ for (int i = 0; i < nb; ++i) {

- memcpy(aux, x[i].scales, 12);
- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+ const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].hmask;
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;

- int8_t * scale = (int8_t *)utmp;
- for (int j = 0; j < 16; ++j) scale[j] -= 32;
+ memcpy(aux, x[i].scales, 12);
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);

+ int8_t * scale = (int8_t *)utmp;
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;

- size_t vl = 32;
- uint8_t m = 1;

- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
- vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
+ size_t vl = 32;
+ uint8_t m = 1;

- int sum_t = 0;
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);

- for (int j = 0; j < QK_K; j += 128) {
+ int sum_t = 0;

- vl = 32;
+ for (int j = 0; j < QK_K; j += 128) {

- // load Q3
- vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+ vl = 32;

- vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
- vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
- vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
- vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
+ // load Q3
+ vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);

- // compute mask for subtraction
- vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
- vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
- m <<= 1;
+ vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+ vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
+ vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
+ vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));

- vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
- vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
- m <<= 1;
+ // compute mask for subtraction
+ vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+ vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
+ m <<= 1;

- vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
- vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
- m <<= 1;
+ vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+ vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
+ m <<= 1;

- vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
- vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
- m <<= 1;
+ vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+ vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
+ m <<= 1;

- // load Q8 and take product with Q3
- vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
- vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
- vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
- vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+ vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+ vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
+ m <<= 1;

- vl = 16;
+ // load Q8 and take product with Q3
+ vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+ vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+ vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+ vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);

- // retrieve lane to multiply with scale
- vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
- vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
- vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
- vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
- vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
- vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
- vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
- vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+ vl = 16;

- vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
- vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
- vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
- vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
+ // retrieve lane to multiply with scale
+ vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+ vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+ vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+ vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+ vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+ vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+ vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+ vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);

- sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);

- q3 += 32; q8 += 128; scale += 8;
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);

- }
+ q3 += 32; q8 += 128; scale += 8;

- const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ }
+
+ const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+ sumf += d*sum_t;
+
+ }
+ break;
+ case 128:
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict qh = x[i].hmask;
+ const int8_t * restrict q8 = y[i].qs;
+
+ int8_t * scale = (int8_t *)utmp;
+ int tmp;
+ __asm__ __volatile__(
+ "vsetivli zero, 12, e8, m1\n\t"
+ "vle8.v v0, (%[s6b])\n\t"
+ "vmv1r.v v2, v0\n\t"
+ "vsetivli zero, 2, e64, m1\n\t"
+ "vmv.v.x v9, %[sh]\n\t"\
+ "vslidedown.vi v1, v0, 1\n\t"
+ "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
+ "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
+ "vsetivli zero, 4, e32, m1\n\t"
+ "vid.v v9\n\t"
+ "vmv.x.s %[tmp], v1\n\t"
+ "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
+ "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
+ "vsrl.vv v4, v1, v9\n\t"
+ "vsrl.vv v2, v0, v8\n\t"
+ "vand.vx v5, v4, %[kmask1]\n\t"
+ "vand.vx v3, v2, %[kmask2]\n\t"
+ "vsll.vi v6, v5, 4\n\t"
+ "vor.vv v7, v6, v3\n\t"
+ "vsetivli zero, 16, e8, m1\n\t"
+ "vsub.vx v0, v7, %[c]\n\t"
+ "vse8.v v0, (%[scale])"
+ : [tmp] "=&r" (tmp)
+ : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
+ , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );

- sumf += d*sum_t;
+ uint8_t m = 1;
+ int isum = 0;
+ for (int j = 0; j < QK_K; j += 128) {
+ __asm__ __volatile__(
+ "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
+ "vle8.v v8, (%[q3])\n\t"
+ "vsrl.vi v10, v8, 2\n\t"
+ "vsrl.vi v12, v8, 4\n\t"
+ "vsrl.vi v14, v8, 6\n\t"
+ "vand.vi v8, v8, 3\n\t"
+ "vand.vi v10, v10, 3\n\t"
+ "vand.vi v12, v12, 3\n\t"
+ "vle8.v v2, (%[qh])\n\t"
+ "vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "vmseq.vx v0, v4, zero\n\t"
+ "vadd.vi v8, v8, -4, v0.t\n\t"
+ "vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "vmseq.vx v0, v4, zero\n\t"
+ "vadd.vi v10, v10, -4, v0.t\n\t"
+ "vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "vmseq.vx v0, v4, zero\n\t"
+ "vadd.vi v12, v12, -4, v0.t\n\t"
+ "vand.vx v4, v2, %[m]\n\t"
+ "slli %[m], %[m], 1\n\t"
+ "vmseq.vx v0, v4, zero\n\t"
+ "vadd.vi v14, v14, -4, v0.t\n\t"
+ "vsetvli zero, %[vl128], e8, m8\n\t"
+ "vle8.v v0, (%[q8])\n\t"
+ "vsetvli zero, %[vl64], e8, m4\n\t"
+ "vwmul.vv v16, v0, v8\n\t"
+ "vwmul.vv v24, v4, v12\n\t"
+ "vsetivli zero, 16, e16, m2\n\t"
+ "vmv.v.x v0, zero\n\t"
+ "vwredsum.vs v10, v16, v0\n\t"
+ "vwredsum.vs v9, v18, v0\n\t"
+ "vwredsum.vs v8, v20, v0\n\t"
+ "vwredsum.vs v7, v22, v0\n\t"
+ "vwredsum.vs v11, v24, v0\n\t"
+ "vwredsum.vs v12, v26, v0\n\t"
+ "vwredsum.vs v13, v28, v0\n\t"
+ "vwredsum.vs v14, v30, v0\n\t"
+ "vsetivli zero, 4, e32, m1\n\t"
+ "vslideup.vi v10, v9, 1\n\t"
+ "vslideup.vi v8, v7, 1\n\t"
+ "vslideup.vi v11, v12, 1\n\t"
+ "vslideup.vi v13, v14, 1\n\t"
+ "vslideup.vi v10, v8, 2\n\t"
+ "vslideup.vi v11, v13, 2\n\t"
+ "vsetivli zero, 8, e32, m2\n\t"\
+ "vle8.v v15, (%[scale])\n\t"
+ "vsext.vf4 v12, v15\n\t"
+ "vmul.vv v10, v10, v12\n\t"
+ "vredsum.vs v0, v10, v0\n\t"
+ "vmv.x.s %[tmp], v0\n\t"
+ "add %[isum], %[isum], %[tmp]"
+ : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
+ : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
+ , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ q3 += 32; q8 += 128; scale += 8;
+ }

+ const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ sumf += d * isum;
+ }
+ break;
+ default:
+ assert(false && "Unsupported vector length");
+ break;
  }

  *s = sumf;
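Both q3_K paths reconstruct the signed 2-bit values the same way: a value stays as-is when its bit in hmask is set and has 4 subtracted when the bit is clear (the masked vsub in the intrinsics, vadd.vi ... -4, v0.t in the assembly). A scalar sketch of that step for one group of 32 values; shift, m and the partial sum are illustrative local names, the pointers are the ones used above:

    // scalar sketch of the q3_K hmask adjustment for one 32-value group
    int isum_part = 0;
    for (int l = 0; l < 32; ++l) {
        const int q = (q3[l] >> shift) & 3;              // 2-bit value
        const int v = q - ((qh[l] & m) ? 0 : 4);         // subtract 4 when bit clear
        isum_part += v * q8[l];
    }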
@@ -6924,69 +7080,181 @@ void lm_ggml_vec_dot_q4_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
  const uint8_t * scales = (const uint8_t*)&utmp[0];
  const uint8_t * mins = (const uint8_t*)&utmp[2];

+ const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;

- for (int i = 0; i < nb; ++i) {
+ switch (vector_length) {
+ case 256:
+ for (int i = 0; i < nb; ++i) {

- size_t vl = 8;
+ size_t vl = 8;

- const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
- const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);
+ const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);

- vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
- vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
- vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+ vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+ vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+ vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);

- memcpy(utmp, x[i].scales, 12);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- const uint32_t uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[2] = uaux;
- utmp[0] &= kmask1;
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;

- vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
- vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
- vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+ vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+ vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+ vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);

- vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
- sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);

- const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
- const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
+ const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;

- vl = 32;
+ vl = 32;

- int32_t sum_1 = 0;
- int32_t sum_2 = 0;
+ int32_t sum_1 = 0;
+ int32_t sum_2 = 0;

- vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+ vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);

- for (int j = 0; j < QK_K/64; ++j) {
- // load Q4
- vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+ for (int j = 0; j < QK_K/64; ++j) {
+ // load Q4
+ vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);

- // load Q8 and multiply it with lower Q4 nibble
- vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
- vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
- vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
- vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
+ // load Q8 and multiply it with lower Q4 nibble
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+ vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+ vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+ vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);

- sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
+ sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];

- // load Q8 and multiply it with upper Q4 nibble
- vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
- vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
- vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
- vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
+ // load Q8 and multiply it with upper Q4 nibble
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+ vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+ vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+ vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);

- sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
+ sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];

- q4 += 32; q8 += 64;
+ q4 += 32; q8 += 64;

- }
+ }
+
+ sumf += d*(sum_1 + sum_2);
+
+ }
+ break;
+ case 128:
+ for (int i = 0; i < nb; ++i) {
+ const float d = y[i].d * LM_GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * LM_GGML_FP16_TO_FP32(x[i].dmin);
+
+ int tmp, tmp2, sumi;
+ __asm__ __volatile__(
+ "vsetivli zero, 12, e8, m1\n\t"
+ "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
+ "vsetivli zero, 4, e32, m1\n\t"
+ "vslidedown.vi v2, v1, 2\n\t"
+ "vmv1r.v v3, v2\n\t"
+ "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
+ "vsetivli zero, 2, e32, m1\n\t"
+ "vmv.v.i v4, 4\n\t"
+ "vand.vx v8, v1, %[kmask1]\n\t"
+ "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
+ "vsrl.vi v6, v1, 6\n\t"
+ "vsrl.vv v7, v2, v5\n\t"
+ "vand.vx v0, v6, %[kmask3]\n\t"
+ "vand.vx v2, v7, %[kmask2]\n\t"
+ "vsll.vi v6, v0, 4\n\t"
+ "li %[t2], 8\n\t"
+ "addi %[t1], %[utmp], 4\n\t"
+ "vor.vv v1, v6, v2\n\t"
+ "vsse32.v v8, (%[utmp]), %[t2]\n\t"
+ "vsse32.v v1, (%[t1]), %[t2]\n\t"
+ "vsetivli zero, 8, e16, m1\n\t"
+ "vle32.v v2, (%[bsums])\n\t"
+ "vnsrl.wi v0, v2, 0\n\t"
+ "vnsrl.wi v1, v2, 16\n\t"
+ "vadd.vv v2, v0, v1\n\t"
+ "vle8.v v3, (%[mins])\n\t"
+ "vzext.vf2 v4, v3\n\t"
+ "vwmul.vv v6, v4, v2\n\t"
+ "vmv.v.x v0, zero\n\t"
+ "vsetivli zero, 8, e32, m2\n\t"
+ "vredsum.vs v0, v6, v0\n\t"
+ "vmv.x.s %[sumi], v0"
+ : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
+ : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
+ , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
+ , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ sumf -= dmin * sumi;
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ sumi = 0;
+ const uint8_t * scale = scales;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ int vl128 = 128, vl64 = 64, vl32 = 32;
+ __asm__ __volatile__(
+ "vsetvli zero, %[vl128], e8, m8\n\t"
+ "vle8.v v8, (%[q8])\n\t"
+ "vsetvli zero, %[vl64], e8, m4\n\t"
+ "vle8.v v0, (%[q4])\n\t"
+ "vsrl.vi v4, v0, 4\n\t"
+ "vand.vi v0, v0, 0xF\n\t"
+ "vsetvli zero, %[vl32], e8, m2\n\t"
+ "vwmul.vv v28, v6, v14\n\t"
+ "vwmul.vv v20, v4, v10\n\t"
+ "vwmul.vv v24, v2, v12\n\t"
+ "vwmul.vv v16, v0, v8\n\t"
+ "vsetivli zero, 4, e32, m1\n\t"
+ "vle8.v v2, (%[scale])\n\t"
+ "vmv.v.x v0, zero\n\t"
+ "vzext.vf4 v1, v2\n\t"
+ "vsetvli zero, %[vl32], e16, m4\n\t"
+ "vwredsum.vs v6, v24, v0\n\t"
+ "vwredsum.vs v7, v28, v0\n\t"
+ "vwredsum.vs v4, v16, v0\n\t"
+ "vwredsum.vs v5, v20, v0\n\t"
+ "vsetivli zero, 4, e32, m1\n\t"
+ "vslideup.vi v6, v7, 1\n\t"
+ "vslideup.vi v4, v5, 1\n\t"
+ "vslideup.vi v4, v6, 2\n\t"
+ "vmul.vv v8, v4, v1\n\t"
+ "vredsum.vs v0, v8, v0\n\t"
+ "vmv.x.s %[tmp], v0\n\t"
+ "add %[sumi], %[sumi], %[tmp]"
+ : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
+ : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
+ , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );

- sumf += d*(sum_1 + sum_2);
+ q4 += 64; q8 += 128; scale += 4;
+ }

+ sumf += d * sumi;
+ }
+ break;
+ default:
+ assert(false && "Unsupported vector length");
+ break;
  }

  *s = sumf;
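Both q4_K paths start with the same "mins" correction that the strided bsums loads implement: adjacent bsums entries are paired, multiplied by the per-sub-block mins, and the total is subtracted once, scaled by dmin. A scalar sketch, assuming the usual QK_K/32 = 8 sub-blocks per superblock and the mins/bsums pointers visible above:

    // scalar sketch of the q4_K mins correction
    int sumi_mins = 0;
    for (int j = 0; j < QK_K/32; ++j) {
        sumi_mins += (y[i].bsums[2*j] + y[i].bsums[2*j+1]) * mins[j];
    }
    sumf -= dmin * sumi_mins;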
@@ -7722,9 +7990,9 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
7722
7990
  const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
7723
7991
  const float dmin = LM_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
7724
7992
 
7725
- vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
7726
- vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
7727
- vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
7993
+ vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);
7994
+ vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);
7995
+ vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);
7728
7996
 
7729
7997
  memcpy(utmp, x[i].scales, 12);
7730
7998
  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -7733,11 +8001,11 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con
7733
8001
  utmp[2] = uaux;
7734
8002
  utmp[0] &= kmask1;
7735
8003
 
7736
- vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
7737
- vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
7738
- vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
8004
+ vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);
8005
+ vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
8006
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);
7739
8007
 
7740
- vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
8008
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
7741
8009
  sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
7742
8010
 
7743
8011
  vl = 32;
@@ -7746,43 +8014,42 @@ void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con

  uint8_t m = 1;
  vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
- vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
+ vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);

  for (int j = 0; j < QK_K/64; ++j) {
  // load Q5 and Q8
- vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
- vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
- vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
+ vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);
+ vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl);
+ vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);

  // compute mask for addition
- vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
- vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
- vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl);
+ vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));
+ vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);
+ vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);
+ vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);
  m <<= 1;

- vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
- vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
- vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
- vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl);
+ vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));
+ vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);
+ vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);
+ vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);
  m <<= 1;

- vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
- vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
+ vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);
+ vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);

- vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
- vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
+ vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);
+ vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);

- vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
- vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
+ vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);
+ vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);

- aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
+ aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);
  q5 += 32; q8 += 64;

  }

- vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
- sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+ sums += aux32 * d;

  }
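The widened q5_K loop above keeps the same arithmetic as before, only at a larger LMUL. As a scalar reference for what the masked `vadd` computes, here is a sketch of one 32-element low-nibble pass (illustrative only; the high-nibble pass works the same way with `q5[l] >> 4` and the next `hm` bit, and the helper name is hypothetical).

```c
#include <stdint.h>

// Scalar sketch of one low-nibble pass of the q5_K loop above: take the low
// 4 bits of q5, add 16 wherever the current hm bit is set (the masked vadd in
// the vector code), multiply by q8 and apply the sub-block scale.
// q5_low_nibble_pass_sketch is a hypothetical helper, not code from the diff.
static int32_t q5_low_nibble_pass_sketch(const uint8_t *q5, const uint8_t *hm,
                                         const int8_t *q8, uint8_t m, int8_t sc) {
    int32_t sum = 0;
    for (int l = 0; l < 32; ++l) {
        const int v = (q5[l] & 0x0F) + ((hm[l] & m) ? 16 : 0);
        sum += v * q8[l];
    }
    return sum * sc;
}
```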
@@ -8158,7 +8425,156 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con

  const int nb = n / QK_K;

- #ifdef __ARM_NEON
+ #ifdef __ARM_FEATURE_SVE
+ const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
+ float sum = 0;
+ svuint8_t m4b = svdup_n_u8(0xf);
+ svint32_t vzero = svdup_n_s32(0);
+ svuint8_t mone = svdup_n_u8(0x30);
+ svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+ svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
+ for (int i = 0; i < nb; ++i) {
+ const float d_all = LM_GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * LM_GGML_RESTRICT q6 = x[i].ql;
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
+
+ const int8_t * LM_GGML_RESTRICT scale = x[i].scales;
+
+ const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+ const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+ const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+ const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+ const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+ const svint64_t prod = svdup_n_s64(0);
+ int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+ svdot_s64(prod, q8sums_2, q6scales_2)));
+ int32_t isum = 0;
+
+ switch (vector_length) {
+ case 128:
+ {
+ const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+ const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+ svint32_t isum_tmp = svdup_n_s32(0);
+ for (int j = 0; j < QK_K/128; ++j) {
+ svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+ svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+ qh += 32;
+ svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+ svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+ svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+ svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+ q6 += 64;
+ svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+ svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+ svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+ svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+ q8 += 64;
+
+ q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+ q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+ q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+ q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+ scale += 4;
+ q8bytes_1 = svld1_s8(pg8_16, q8);
+ q8bytes_2 = svld1_s8(pg8_16, q8+16);
+ q8bytes_3 = svld1_s8(pg8_16, q8+32);
+ q8bytes_4 = svld1_s8(pg8_16, q8+48);
+ q8 += 64;
+
+ q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+ q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+ q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+ q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+ scale += 4;
+ }
+ isum += svaddv_s32(pg32_4, isum_tmp);
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
+ }
+ break;
+ case 256:
+ case 512:
+ {
+ const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+ const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+ const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+ svint32_t isum_tmp = svdup_n_s32(0);
+ for (int j = 0; j < QK_K/128; j++) {
+ svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+ qh += 32;
+ svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+ svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+ q6 += 64;
+ svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+ svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+ svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+ svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+ q8 += 128;
+ q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+ q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+ q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+ q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
+ svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+ scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+ scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+ svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+ scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+ scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+ svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+ scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+ scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+ svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+ scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+ scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+ svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+ svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+ svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+ svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+ scale += 8;
+ }
+ isum += svaddv_s32(pg32_8, isum_tmp);
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
+ }
+ break;
+ default:
+ assert(false && "Unsupported vector length");
+ break;
+ }
+ }
+
+ *s = sum;
+
+ #elif __ARM_NEON
  float sum = 0;

  const uint8x16_t m4b = vdupq_n_u8(0xF);
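The new SVE q6_K kernel above chooses fixed-width predicates from the runtime vector length. Below is a minimal sketch of that width check using the standard ACLE query `svcntb()`; the diff itself routes the query through `lm_ggml_cpu_get_sve_cnt()`, whose definition is outside this hunk, and the helper names below are hypothetical.

```c
// Sketch only: query the SVE vector length with the ACLE intrinsic svcntb()
// and pick the byte-count predicate the way the kernel above does
// (SV_VL16 for 128-bit vectors, SV_VL32 for 256/512-bit vectors).
#include <assert.h>

#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>

static inline int sve_vector_length_bits(void) {
    return (int)svcntb() * 8;   // bytes per vector * 8
}

static svbool_t q6_block_predicate_sketch(void) {
    switch (sve_vector_length_bits()) {
        case 128: return svptrue_pat_b8(SV_VL16);  // 16 active bytes per load
        case 256:
        case 512: return svptrue_pat_b8(SV_VL32);  // 32 active bytes per load
        default:  assert(0 && "Unsupported vector length"); return svpfalse_b();
    }
}
#endif
```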
@@ -8518,85 +8934,168 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, con

  #elif defined __riscv_v_intrinsic

+ const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;
- for (int i = 0; i < nb; ++i) {

- const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ switch (vector_length) {
+ case 256:
+ for (int i = 0; i < nb; ++i) {

- const uint8_t * LM_GGML_RESTRICT q6 = x[i].ql;
- const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
- const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
+ const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;

- const int8_t * LM_GGML_RESTRICT scale = x[i].scales;
+ const uint8_t * LM_GGML_RESTRICT q6 = x[i].ql;
+ const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
+ const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;

- size_t vl;
+ const int8_t * LM_GGML_RESTRICT scale = x[i].scales;

- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ size_t vl;

- int sum_t = 0;
- int is = 0;
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);

- for (int j = 0; j < QK_K/128; ++j) {
+ int sum_t = 0;
+ int is = 0;

- vl = 32;
+ for (int j = 0; j < QK_K/128; ++j) {

- // load qh
- vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+ vl = 32;

- // load Q6
- vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
- vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+ // load qh
+ vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);

- vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
- vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
- vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
- vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+ // load Q6
+ vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+ vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);

- vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
- vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
- vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
- vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+ vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+ vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+ vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+ vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);

- vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
- vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
- vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
- vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+ vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+ vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+ vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+ vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);

- vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
- vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
- vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
- vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+ vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+ vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+ vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+ vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);

- // load Q8 and take product
- vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
- vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
- vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
- vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+ vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+ vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+ vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+ vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);

- vl = 16;
+ // load Q8 and take product
+ vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+ vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+ vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+ vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);

- vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
- vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
- vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
- vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
- vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
- vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
- vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
- vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+ vl = 16;

- vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
- vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
- vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
- vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+ vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+ vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+ vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+ vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+ vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+ vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+ vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+ vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);

- sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);

- q6 += 64; qh += 32; q8 += 128; is=8;
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);

- }
+ q6 += 64; qh += 32; q8 += 128; is=8;

- sumf += d * sum_t;
+ }

+ sumf += d * sum_t;
+
+ }
+ break;
+ case 128:
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = LM_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+ const uint8_t * restrict q6 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const int8_t * restrict scale = x[i].scales;
+
+ int sum_t = 0;
+ int t0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ __asm__ __volatile__(
+ "vsetvli zero, %[vl32], e8, m2\n\t"
+ "vle8.v v4, (%[qh])\n\t"
+ "vsll.vi v0, v4, 4\n\t"
+ "vsll.vi v2, v4, 2\n\t"
+ "vsrl.vi v6, v4, 2\n\t"
+ "vsetvli zero, %[vl64], e8, m4\n\t"
+ "vle8.v v8, (%[q6])\n\t"
+ "vsrl.vi v12, v8, 4\n\t"
+ "vand.vi v8, v8, 0xF\n\t"
+ "vsetvli zero, %[vl128], e8, m8\n\t"
+ "vand.vx v0, v0, %[mask]\n\t"
+ "vor.vv v8, v8, v0\n\t"
+ "vle8.v v0, (%[q8])\n\t"
+ "vsub.vx v8, v8, %[vl32]\n\t"
+ "vsetvli zero, %[vl64], e8, m4\n\t"
+ "vwmul.vv v16, v0, v8\n\t"
+ "vwmul.vv v24, v4, v12\n\t"
+ "vsetivli zero, 16, e16, m2\n\t"
+ "vmv.v.x v0, zero\n\t"
+ "vwredsum.vs v10, v16, v0\n\t"
+ "vwredsum.vs v9, v18, v0\n\t"
+ "vwredsum.vs v8, v20, v0\n\t"
+ "vwredsum.vs v7, v22, v0\n\t"
+ "vwredsum.vs v11, v24, v0\n\t"
+ "vwredsum.vs v12, v26, v0\n\t"
+ "vwredsum.vs v13, v28, v0\n\t"
+ "vwredsum.vs v14, v30, v0\n\t"
+ "vsetivli zero, 4, e32, m1\n\t"
+ "vslideup.vi v10, v9, 1\n\t"
+ "vslideup.vi v8, v7, 1\n\t"
+ "vslideup.vi v11, v12, 1\n\t"
+ "vslideup.vi v13, v14, 1\n\t"
+ "vslideup.vi v10, v8, 2\n\t"
+ "vslideup.vi v11, v13, 2\n\t"
+ "vsetivli zero, 8, e32, m2\n\t"
+ "vle8.v v2, (%[scale])\n\t"
+ "vsext.vf4 v4, v2\n\t"
+ "vmul.vv v2, v4, v10\n\t"
+ "vredsum.vs v0, v2, v0\n\t"
+ "vmv.x.s %[t0], v0\n\t"
+ "add %[sumi], %[sumi], %[t0]"
+ : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
+ : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+ , [mask] "r" (0x30)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ q6 += 64; qh += 32; q8 += 128; scale += 8;
+ }
+
+ sumf += d * sum_t;
+
+ }
+ break;
+ default:
+ assert(false && "Unsupported vector length");
+ break;
  }

  *s = sumf;
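Both q6_K paths above (SVE and RISC-V) implement the same per-chunk arithmetic: each 6-bit quant is rebuilt from four low bits in `ql` and two high bits in `qh`, re-centered by -32, and dotted with `q8` under per-16-element scales, with the pointers advancing by 64/32/128/8 per 128-value chunk as in the diff. A scalar sketch of one such chunk follows (illustrative only, with a hypothetical function name).

```c
#include <stdint.h>

// Scalar sketch of one 128-value q6_K chunk: rebuild each 6-bit quant from
// ql (4 low bits) and qh (2 high bits), subtract 32, and accumulate the dot
// product with q8 under the per-16-element scales. The caller multiplies the
// result by d = FP16(x->d) * y->d and advances ql/qh/q8/sc by 64/32/128/8,
// as the vector code above does. q6_chunk_dot_sketch is a hypothetical helper.
static int32_t q6_chunk_dot_sketch(const uint8_t *ql, const uint8_t *qh,
                                   const int8_t *q8, const int8_t *sc) {
    int32_t sum = 0;
    for (int l = 0; l < 32; ++l) {
        const int is = l / 16;
        const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
        const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
        const int8_t q3 = (int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
        const int8_t q4 = (int8_t)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
        sum += sc[is + 0] * q1 * q8[l +  0];
        sum += sc[is + 2] * q2 * q8[l + 32];
        sum += sc[is + 4] * q3 * q8[l + 64];
        sum += sc[is + 6] * q4 * q8[l + 96];
    }
    return sum;
}
```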