whisper.rn 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. package/cpp/ggml-alloc.c +11 -4
  2. package/cpp/ggml-backend-reg.cpp +8 -0
  3. package/cpp/ggml-backend.cpp +0 -2
  4. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  5. package/cpp/ggml-cpu/ggml-cpu-impl.h +3 -1
  6. package/cpp/ggml-cpu/ggml-cpu.c +50 -21
  7. package/cpp/ggml-cpu/ops.cpp +458 -349
  8. package/cpp/ggml-cpu/ops.h +4 -4
  9. package/cpp/ggml-cpu/repack.cpp +143 -29
  10. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  11. package/cpp/ggml-cpu/unary-ops.cpp +16 -0
  12. package/cpp/ggml-cpu/unary-ops.h +2 -0
  13. package/cpp/ggml-cpu/vec.cpp +17 -0
  14. package/cpp/ggml-cpu/vec.h +10 -0
  15. package/cpp/ggml-impl.h +17 -1
  16. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  17. package/cpp/ggml-metal/ggml-metal-device.cpp +101 -4
  18. package/cpp/ggml-metal/ggml-metal-device.h +8 -1
  19. package/cpp/ggml-metal/ggml-metal-device.m +216 -14
  20. package/cpp/ggml-metal/ggml-metal-impl.h +90 -2
  21. package/cpp/ggml-metal/ggml-metal-ops.cpp +346 -85
  22. package/cpp/ggml-metal/ggml-metal-ops.h +2 -0
  23. package/cpp/ggml-metal/ggml-metal.cpp +5 -0
  24. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  25. package/cpp/ggml.c +154 -5
  26. package/cpp/ggml.h +73 -0
  27. package/cpp/whisper.cpp +5 -1
  28. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  29. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
  30. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  31. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  32. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  33. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  34. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
  35. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  36. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  37. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  38. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  39. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  40. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
  41. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  42. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  43. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  45. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
  46. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  47. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  48. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  49. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  50. package/package.json +1 -1
  51. package/whisper-rn.podspec +1 -1
  52. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  53. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
package/cpp/ggml-alloc.c CHANGED
@@ -226,16 +226,23 @@ static struct buffer_address wsp_ggml_dyn_tallocr_alloc(struct wsp_ggml_dyn_tall
     }
 
     if (best_fit_block == -1) {
-        // no suitable block found, try the last block (this will grow a chunks size)
+        // no suitable block found, try the last block (this may grow a chunks size)
+        int64_t best_reuse = INT64_MIN;
         for (int c = 0; c < alloc->n_chunks; ++c) {
             struct tallocr_chunk * chunk = alloc->chunks[c];
             if (chunk->n_free_blocks > 0) {
                 struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];
                 max_avail = MAX(max_avail, block->size);
-                if (block->size >= size) {
+                int64_t reuse_factor = chunk->max_size - block->offset - size;
+                // reuse_factor < 0 : amount of extra memory that needs to be allocated
+                // reuse_factor = 0 : allocated free space exactly matches tensor size
+                // reuse_factor > 0 : superfluous memory that will remain unused
+                bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
+                bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse;
+                if (block->size >= size && (better_reuse || better_fit)) {
                     best_fit_chunk = c;
                     best_fit_block = chunk->n_free_blocks - 1;
-                    break;
+                    best_reuse = reuse_factor;
                 }
             }
         }
@@ -268,7 +275,7 @@ static struct buffer_address wsp_ggml_dyn_tallocr_alloc(struct wsp_ggml_dyn_tall
 #ifdef WSP_GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, addr, tensor);
     size_t cur_max = addr.offset + size;
-    if (cur_max > alloc->max_size[addr.chunk]) {
+    if (cur_max > chunk->max_size) {
         // sort allocated_tensors by chunk/offset
         for (int i = 0; i < 1024; i++) {
             for (int j = i + 1; j < 1024; j++) {
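The hunk above replaces the previous first-fit-and-break search in wsp_ggml_dyn_tallocr_alloc with a "reuse factor" heuristic: among chunks whose last free block can hold the tensor, prefer the candidate that wastes the least space, and among candidates that would all need to grow, prefer the one that needs the least extra memory. As a rough illustration of that rule only (simplified stand-in types; demo_chunk, demo_block and pick_chunk are hypothetical names, not the package's structs):

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdint.h>

    struct demo_block { size_t offset, size; };
    struct demo_chunk { size_t max_size; struct demo_block last_free; };

    // Returns the index of the chunk whose last free block should be used for an
    // allocation of `size`, mirroring the reuse_factor selection in the hunk above:
    // non-negative reuse factors (no growth needed) are minimized; otherwise the
    // least-negative factor (least extra growth) wins.
    static int pick_chunk(const struct demo_chunk * chunks, int n_chunks, size_t size) {
        int     best_chunk = -1;
        int64_t best_reuse = INT64_MIN;
        for (int c = 0; c < n_chunks; ++c) {
            const struct demo_block * block = &chunks[c].last_free;
            int64_t reuse_factor = (int64_t) chunks[c].max_size - (int64_t) block->offset - (int64_t) size;
            bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
            bool better_fit   = reuse_factor >= 0 && reuse_factor < best_reuse;
            if (block->size >= size && (better_reuse || better_fit)) {
                best_chunk = c;
                best_reuse = reuse_factor;
            }
        }
        return best_chunk;
    }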
@@ -57,6 +57,10 @@
 #include "ggml-opencl.h"
 #endif
 
+#ifdef WSP_GGML_USE_HEXAGON
+#include "ggml-hexagon.h"
+#endif
+
 #ifdef WSP_GGML_USE_BLAS
 #include "ggml-blas.h"
 #endif
@@ -199,6 +203,9 @@ struct wsp_ggml_backend_registry {
 #ifdef WSP_GGML_USE_OPENCL
         register_backend(wsp_ggml_backend_opencl_reg());
 #endif
+#ifdef WSP_GGML_USE_HEXAGON
+        register_backend(wsp_ggml_backend_hexagon_reg());
+#endif
 #ifdef WSP_GGML_USE_CANN
         register_backend(wsp_ggml_backend_cann_reg());
 #endif
@@ -598,6 +605,7 @@ void wsp_ggml_backend_load_all_from_path(const char * dir_path) {
     wsp_ggml_backend_load_best("sycl", silent, dir_path);
     wsp_ggml_backend_load_best("vulkan", silent, dir_path);
    wsp_ggml_backend_load_best("opencl", silent, dir_path);
+    wsp_ggml_backend_load_best("hexagon", silent, dir_path);
     wsp_ggml_backend_load_best("musa", silent, dir_path);
     wsp_ggml_backend_load_best("cpu", silent, dir_path);
     // check the environment variable WSP_GGML_BACKEND_PATH to load an out-of-tree backend
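With the hunks above, a Hexagon backend is registered at build time when WSP_GGML_USE_HEXAGON is defined, and the dynamic loader now also probes for a loadable "hexagon" backend. As a hedged sketch only (it assumes this package exposes wsp_-prefixed versions of ggml's public registry API such as wsp_ggml_backend_load_all, wsp_ggml_backend_reg_count, wsp_ggml_backend_reg_get and wsp_ggml_backend_reg_name; check package/cpp/ggml-backend.h for the exact names), an application could list whatever backends end up registered:

    // Sketch only: the wsp_-prefixed registry calls are assumed, not verified
    // against this package's headers.
    #include <stdio.h>
    #include "ggml-backend.h"

    static void list_backends(void) {
        wsp_ggml_backend_load_all(); // also probes "hexagon" after this change
        for (size_t i = 0; i < wsp_ggml_backend_reg_count(); ++i) {
            wsp_ggml_backend_reg_t reg = wsp_ggml_backend_reg_get(i);
            printf("registered backend: %s\n", wsp_ggml_backend_reg_name(reg));
        }
    }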
@@ -1698,8 +1698,6 @@ bool wsp_ggml_backend_sched_reserve(wsp_ggml_backend_sched_t sched, struct wsp_g
     WSP_GGML_ASSERT(sched);
     WSP_GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
 
-    wsp_ggml_backend_sched_reset(sched);
-
     wsp_ggml_backend_sched_synchronize(sched);
 
     wsp_ggml_backend_sched_split_graph(sched, measure_graph);
@@ -2044,6 +2044,26 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, c
 
 }
 
+#ifdef __ARM_FEATURE_SVE
+static inline svuint32_t wsp_ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {
+    const svbool_t pg_all = svptrue_pat_b32(SV_VL4);
+    const svbool_t pg_false = svpfalse_b(); // 0x0000
+    const svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8); // 0x00ff
+    const svbool_t pg_odd = svzip1_b32(pg_false, pg_lo_8);
+
+    svuint32_t vutmp_hi, vutmp_lo;
+    svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales);
+    vutmp_hi = svzip1_u32(vx01, vx01);
+    vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);
+    vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));
+    const svuint32_t vx2 = svdup_u32(vx_scales[2]);
+    vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));
+    vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));
+    svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);
+    return vutmp;
+}
+#endif
+
 void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, size_t bx, const void * WSP_GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
 #ifdef __ARM_FEATURE_MATMUL_INT8
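The SVE helper added above unpacks a q4_K super-block's packed 6-bit scales and minimums into the lane layout used by the mmla path below. For readability, here is a scalar illustration of the same 6-bit unpacking of the 12 scale bytes as ggml's q4_K format defines it; the function is an illustration written for this note, not code from the package:

    #include <stdint.h>

    // Unpack the j-th (0..7) 6-bit scale *d and min *m from the 12-byte q4_K
    // scale block q. The SVE helper above produces the same values, rearranged
    // for the two-row mmla kernels.
    static inline void demo_get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
        if (j < 4) {
            *d = q[j] & 63;
            *m = q[j + 4] & 63;
        } else {
            *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
            *m = (q[j + 4] >>  4) | ((q[j    ] >> 6) << 4);
        }
    }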
@@ -2066,8 +2086,220 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, c
     static const uint32_t kmask3 = 0x03030303;
 
     uint32_t utmp[4];
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = wsp_ggml_cpu_get_sve_cnt()*8;
+#endif
 
-#if defined(__ARM_FEATURE_MATMUL_INT8)
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
+
+        const block_q4_K * WSP_GGML_RESTRICT vx0 = vx;
+        const block_q8_K * WSP_GGML_RESTRICT vy0 = vy;
+        const block_q4_K * WSP_GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);
+        const block_q8_K * WSP_GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
+
+        union {
+            uint32_t u32[8];
+            uint64_t u64[4];
+        } new_utmp;
+
+        svfloat32_t sumf1 = svdup_n_f32(0);
+
+        switch (vector_length) {
+            case 128:
+            {
+                svbool_t pg_false = svpfalse_b();
+                svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8);
+                svbool_t vmins_mask1= svzip1_b32(pg_lo_8, pg_false);
+                svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);
+                svbool_t pg128_all = svptrue_pat_b8(SV_VL16);
+                for (int i = 0; i < nb; ++i) {
+                    svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
+                    svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx1[i].dmin)));
+                    svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+                    svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);
+                    const uint8_t * WSP_GGML_RESTRICT q4_0 = vx0[i].qs;
+                    const int8_t * WSP_GGML_RESTRICT q8_0 = vy0[i].qs;
+                    const uint8_t * WSP_GGML_RESTRICT q4_1 = vx1[i].qs;
+                    const int8_t * WSP_GGML_RESTRICT q8_1 = vy1[i].qs;
+                    svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);
+                    svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);
+                    svint16_t sum_tmp1 = svuzp1_s16(lo, hi);
+                    svint16_t sum_tmp2 = svuzp2_s16(lo, hi);
+                    svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
+                    lo = svld1_s16(pg128_all, vy1[i].bsums + 0);
+                    hi = svld1_s16(pg128_all, vy1[i].bsums + 8);
+                    sum_tmp1 = svuzp1(lo, hi);
+                    sum_tmp2 = svuzp2(lo, hi);
+                    svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
+                    svuint32_t decoded_scales0 = wsp_ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
+                    svuint32_t decoded_scales1 = wsp_ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
+                    svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
+                    svst2_u32(pg128_all, new_utmp.u32, decoded_scales);
+                    svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));
+                    svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));
+                    svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));
+                    svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));
+                    svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);
+                    svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));
+                    svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));
+                    svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);
+                    svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
+                    svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
+                    svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);
+                    svint32_t svscales, sumi1, sumi2;
+                    svint32_t acc_sumif1 = svdup_n_s32(0);
+                    svint32_t acc_sumif2 = svdup_n_s32(0);
+                    svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,
+                             q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;
+                    #pragma GCC unroll 1
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));
+                        q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));
+                        q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));
+                        q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));
+                        l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        q8bytes_0_h = svld1_s8(pg128_all, q8_0);
+                        q8bytes_1_h = svld1_s8(pg128_all, q8_1);
+                        q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);
+                        q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);
+                        r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
+                        acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);
+
+                        q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));
+                        q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));
+                        q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));
+                        q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));
+                        l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);
+                        q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);
+                        q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);
+                        q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);
+                        r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
+                        acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);
+                        q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
+                    }
+                    sumf1 = svmla_f32_x(pg128_all,
+                                svmla_f32_x(pg128_all,
+                                    sumf1,
+                                    svcvt_f32_x(pg128_all,
+                                        svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),
+                                    svsuper_block_scales),
+                                svdmins,
+                                svcvt_f32_s32_x(pg128_all, svsumfs_tmp));
+                } //end of for nb
+            } // end of case 128
+            break;
+            case 256:
+            case 512:
+            {
+                const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
+                for (int i = 0; i < nb; ++i) {
+                    const uint8_t * WSP_GGML_RESTRICT q4_0 = vx0[i].qs;
+                    const int8_t * WSP_GGML_RESTRICT q8_0 = vy0[i].qs;
+                    const uint8_t * WSP_GGML_RESTRICT q4_1 = vx1[i].qs;
+                    const int8_t * WSP_GGML_RESTRICT q8_1 = vy1[i].qs;
+                    svint32_t svscales, sumi1, sumi2;
+                    svint32_t acc_sumif1 = svdup_n_s32(0);
+                    svint32_t acc_sumif2 = svdup_n_s32(0);
+                    svint8_t l0, l1, l2, l3, r0, r1, r2, r3;
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+                    svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
+                    svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);
+                    svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx1[i].dmin)));
+                    svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+                    svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));
+                    svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);
+                    svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
+                    svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
+                    svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);
+                    svuint32_t decoded_scales0 = wsp_ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
+                    svuint32_t decoded_scales1 = wsp_ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
+                    svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
+                    svst2_u32(pg8_16, new_utmp.u32, decoded_scales);
+                    svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
+                    svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
+                    svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]);
+                    svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]);
+                    svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));
+                    svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));
+                    svint64_t dot_prod_0 = svdot_s64(svdup_s64(0), new_svmins8_0, new_svq8sums_0);
+                    svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);
+                    svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);
+                    svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);
+
+                    #pragma GCC unroll 1
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);
+                        svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);
+                        svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);
+                        svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);
+                        l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
+                        l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
+                        l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
+                        l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
+                        svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);
+                        svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);
+                        svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);
+                        r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                        r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                        r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
+                        r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
+                        sumi1 = svmmla(svmmla(svdup_n_s32(0), r0, l0), r1, l1);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
+                        acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);
+                        sumi2 = svmmla(svmmla(svdup_n_s32(0), r2, l2), r3, l3);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
+                        acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);
+                        q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
+                    }
+                    svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);
+                    svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);
+                    acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);
+                    sumf1 = svmla_f32_x(pg32_4,
+                                svmla_f32_x(pg32_4,
+                                    sumf1,
+                                    svcvt_f32_x(pg32_4, acc_sumif),
+                                    svsuper_block_scales),
+                                svdmins,
+                                svsumfs_tmp);
+                } // end of for nb
+            } // end of case 256-512
+            break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+
+        svst1_f32(pg32_2, s, sumf1);
+        svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));
+
+        return;
+    }
+#elif defined(__ARM_FEATURE_MATMUL_INT8)
     if (nrc == 2) {
         const block_q4_K * WSP_GGML_RESTRICT x0 = x;
         const block_q4_K * WSP_GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
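The two-row SVE kernels above accumulate with svmmla_s32, the SVE form of the SMMLA int8 matrix-multiply-accumulate instruction. As a scalar reference for what a single 128-bit svmmla_s32 step computes (an illustration written for this note, not package code): the accumulator is a 2x2 int32 tile and each operand holds two rows of eight signed bytes.

    #include <stdint.h>

    // Scalar semantics of one 128-bit SMMLA step: acc[i][j] += dot(a row i, b row j).
    static void demo_smmla_2x2(int32_t acc[2][2], const int8_t a[2][8], const int8_t b[2][8]) {
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 2; ++j) {
                int32_t dot = 0;
                for (int k = 0; k < 8; ++k) {
                    dot += (int32_t) a[i][k] * (int32_t) b[j][k];
                }
                acc[i][j] += dot;
            }
        }
    }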
@@ -2235,7 +2467,6 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, c
         const uint8_t * WSP_GGML_RESTRICT q4 = x[i].qs;
         const int8_t * WSP_GGML_RESTRICT q8 = y[i].qs;
 
-        const int vector_length = wsp_ggml_cpu_get_sve_cnt()*8;
         const svuint8_t m4b = svdup_n_u8(0xf);
         const svint32_t mzero = svdup_n_s32(0);
         svint32_t sumi1 = svdup_n_s32(0);
@@ -2480,7 +2711,201 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, c
 
     const int nb = n / QK_K;
 
-#if defined(__ARM_FEATURE_MATMUL_INT8)
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = wsp_ggml_cpu_get_sve_cnt()*8;
+#endif
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
+
+        svfloat32_t sum = svdup_n_f32(0);
+
+        const block_q6_K * WSP_GGML_RESTRICT vx0 = vx;
+        const block_q8_K * WSP_GGML_RESTRICT vy0 = vy;
+        const block_q6_K * WSP_GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);
+        const block_q8_K * WSP_GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
+
+        switch (vector_length) {
+            case 128:
+            {
+                const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);
+                for (int i = 0; i < nb; ++i) {
+                    const uint8_t * WSP_GGML_RESTRICT ql0 = vx0[i].ql;
+                    const uint8_t * WSP_GGML_RESTRICT qh0 = vx0[i].qh;
+                    const uint8_t * WSP_GGML_RESTRICT ql1 = vx1[i].ql;
+                    const uint8_t * WSP_GGML_RESTRICT qh1 = vx1[i].qh;
+                    const int8_t * WSP_GGML_RESTRICT q80 = vy0[i].qs;
+                    const int8_t * WSP_GGML_RESTRICT q81 = vy1[i].qs;
+
+                    const int8_t * WSP_GGML_RESTRICT scale0 = vx0[i].scales;
+                    const int8_t * WSP_GGML_RESTRICT scale1 = vx1[i].scales;
+
+                    svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
+                    // process q8sum summation 128 bit route
+                    const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);
+                    const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);
+                    const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);
+                    const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);
+                    const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);
+                    const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));
+                    const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));
+                    const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);
+                    const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));
+                    const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));
+                    const svint64_t prod = svdup_n_s64(0);
+
+                    svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));
+                    svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));
+                    svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);
+                    svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));
+                    svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));
+                    svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);
+                    svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);
+
+                    // process mmla
+                    svint8_t l0, l1, r0, r1;
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        for (int k = 0; k < 8; ++k) {
+                            svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));
+                            svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));
+                            svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));
+                            svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));
+                            const int ql_pos = (k/4)*4;
+                            svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);
+                            svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);
+                            const int qh_pos = (k/2)*2;
+                            svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);
+                            svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);
+                            svint8_t q6bytes_0, q6bytes_1;
+                            if (qh_pos <= 4) {
+                                q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
+                                q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
+                            } else {
+                                q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));
+                                q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));
+                            }
+                            svint8_t q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));
+                            svint8_t q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));
+                            l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
+                            isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);
+                        }
+                        qh0 += 32; qh1 += 32;
+                        ql0 += 64; ql1 += 64;
+                        q80 += 128; q81 += 128;
+                        scale0 += 8; scale1 += 8;
+                    }
+                    sum = svmla_f32_x(pg128_all, sum,
+                              svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,
+                                  svisum_mins, svdup_n_s32(-32))),
+                              svsuper_block_scales);
+                }
+            } // end of case 128
+            break;
+            case 256:
+            case 512:
+            {
+                const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
+                const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                for (int i = 0; i < nb; ++i) {
+                    const uint8_t * WSP_GGML_RESTRICT ql0 = vx0[i].ql;
+                    const uint8_t * WSP_GGML_RESTRICT qh0 = vx0[i].qh;
+                    const uint8_t * WSP_GGML_RESTRICT ql1 = vx1[i].ql;
+                    const uint8_t * WSP_GGML_RESTRICT qh1 = vx1[i].qh;
+                    const int8_t * WSP_GGML_RESTRICT q80 = vy0[i].qs;
+                    const int8_t * WSP_GGML_RESTRICT q81 = vy1[i].qs;
+
+                    const int8_t * WSP_GGML_RESTRICT scale0 = vx0[i].scales;
+                    const int8_t * WSP_GGML_RESTRICT scale1 = vx1[i].scales;
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(WSP_GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+                    svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
+                    svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);
+                    // process q8sum summation 256 bit route
+                    const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);
+                    const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);
+                    const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));
+                    const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));
+                    const svint64_t prod = svdup_n_s64(0);
+                    svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));
+                    svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));
+                    svint32_t isum_tmp3 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));
+                    svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));
+                    svint32_t isum_tmp5 = svtrn1_s32(isum_tmp1, isum_tmp2);
+                    svint32_t isum_tmp6 = svtrn1_s32(isum_tmp3, isum_tmp4);
+                    svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t isum_tmp9 = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);
+                    svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));
+                    svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);
+
+                    // process mmla
+                    svint8_t l0, l1, r0, r1;
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        for (int k = 0; k < 8; k+=2) { // process 2 block
+                            svuint8_t qhbits_0 = svld1_u8(pg256_all, qh0);
+                            svuint8_t qhbits_1 = svld1_u8(pg256_all, qh1);
+                            svuint8_t q6bits_0 = svld1_u8(pg256_all, ql0+32*((k%4)/2));
+                            svuint8_t q6bits_1 = svld1_u8(pg256_all, ql1+32*((k%4)/2));
+                            const int ql_pos = (k/4)*4;
+                            svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);
+                            svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);
+                            const int qh_pos = (k/2)*2;
+                            svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);
+                            svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);
+                            svint8_t q6bytes_0, q6bytes_1;
+                            if (qh_pos <= 4) {
+                                q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
+                                q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
+                            } else {
+                                q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));
+                                q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));
+                            }
+                            svint8_t q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));
+                            svint8_t q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));
+                            l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
+                            svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));
+                            isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);
+                            isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);
+                        }
+                        qh0 += 32; qh1 += 32;
+                        ql0 += 64; ql1 += 64;
+                        q80 += 128; q81 += 128;
+                        scale0 += 8; scale1 += 8;
+                    } // end of for
+                    svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);
+                    isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);
+                    sum = svmla_f32_x(pg32_4, sum,
+                              svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,
+                                  svisum_mins, svdup_n_s32(-32))),
+                              svsuper_block_scales);
+                }
+            } // end of case 256
+            break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        } // end of switch
+
+        svst1_f32(pg32_2, s, sum);
+        svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));
+
+        return;
+    }
+#elif defined(__ARM_FEATURE_MATMUL_INT8)
     if (nrc == 2) {
         const block_q6_K * WSP_GGML_RESTRICT x0 = x;
         const block_q6_K * WSP_GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
@@ -2594,27 +3019,6 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, c
     // adjust bias, apply superblock scale
     {
         int32_t bias[4];
-#ifdef __ARM_FEATURE_SVE
-        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
-        const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
-        const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
-        const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
-        const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
-        const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
-        const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
-        const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
-        const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
-        const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
-        const svint64_t zero = svdup_n_s64(0);
-        bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
-                                                                       svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
-        bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
-                                                                       svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
-        bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
-                                                                       svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
-        bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
-                                                                       svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
-#else
         // NEON doesn't support int16 dot product, fallback to separated mul and add
         const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
         const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
@@ -2646,7 +3050,6 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, c
                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
         bias[3] = vaddvq_s32(prod);
 
-#endif
         const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
 
         const float32x4_t superblock_scale = {
@@ -2672,7 +3075,6 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, c
 #endif
 
 #ifdef __ARM_FEATURE_SVE
-    const int vector_length = wsp_ggml_cpu_get_sve_cnt()*8;
     float sum = 0;
     svuint8_t m4b = svdup_n_u8(0xf);
     svint32_t vzero = svdup_n_s32(0);
@@ -500,13 +500,15 @@ inline static int32x4_t wsp_ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b
 
 #endif
 
-#if defined(__loongarch_asx)
+#if defined(__loongarch_sx)
 /* float type data load instructions */
 static __m128 __lsx_vreplfr2vr_s(const float val) {
     v4f32 res = {val, val, val, val};
     return (__m128)res;
 }
+#endif
 
+#if defined(__loongarch_asx)
 static __m256 __lasx_xvreplfr2vr_s(const float val) {
     v8f32 res = {val, val, val, val, val, val, val, val};
     return (__m256)res;
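The guard change above makes the __lsx_vreplfr2vr_s splat helper available to any LoongArch build with LSX (__loongarch_sx) rather than only when LASX (__loongarch_asx) is enabled. A trivial, illustrative use (it assumes the helper defined in this header is in scope; demo_splat_half is a hypothetical name, not part of the package):

    // Illustration only: splat a scalar into a 128-bit LSX vector via the
    // helper above; compiles only where __loongarch_sx is defined.
    #if defined(__loongarch_sx)
    static inline __m128 demo_splat_half(void) {
        return __lsx_vreplfr2vr_s(0.5f); // {0.5f, 0.5f, 0.5f, 0.5f}
    }
    #endif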