whisper.rn 0.5.3 → 0.5.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. package/README.md +1 -1
  2. package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
  3. package/android/src/main/jni.cpp +13 -0
  4. package/cpp/ggml-alloc.c +78 -26
  5. package/cpp/ggml-alloc.h +9 -0
  6. package/cpp/ggml-backend-impl.h +1 -1
  7. package/cpp/ggml-backend-reg.cpp +19 -3
  8. package/cpp/ggml-backend.cpp +72 -20
  9. package/cpp/ggml-backend.h +2 -1
  10. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  11. package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
  12. package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
  13. package/cpp/ggml-cpu/arch-fallback.h +50 -2
  14. package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
  15. package/cpp/ggml-cpu/ggml-cpu.c +139 -58
  16. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  17. package/cpp/ggml-cpu/ops.cpp +170 -18
  18. package/cpp/ggml-cpu/ops.h +1 -0
  19. package/cpp/ggml-cpu/repack.cpp +531 -5
  20. package/cpp/ggml-cpu/repack.h +14 -0
  21. package/cpp/ggml-cpu/simd-mappings.h +16 -18
  22. package/cpp/ggml-cpu/vec.cpp +41 -1
  23. package/cpp/ggml-cpu/vec.h +241 -138
  24. package/cpp/ggml-cpu.h +1 -0
  25. package/cpp/ggml-impl.h +0 -4
  26. package/cpp/ggml-metal/ggml-metal-context.m +26 -16
  27. package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
  28. package/cpp/ggml-metal/ggml-metal-device.h +87 -65
  29. package/cpp/ggml-metal/ggml-metal-device.m +263 -104
  30. package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
  31. package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
  32. package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
  33. package/cpp/ggml-metal/ggml-metal.cpp +6 -5
  34. package/cpp/ggml-metal/ggml-metal.metal +404 -34
  35. package/cpp/ggml.c +110 -31
  36. package/cpp/ggml.h +51 -12
  37. package/cpp/jsi/RNWhisperJSI.cpp +1 -0
  38. package/cpp/whisper.cpp +17 -4
  39. package/ios/CMakeLists.txt +21 -1
  40. package/ios/RNWhisperContext.mm +5 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  49. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  56. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  58. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  59. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  64. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  65. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  66. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  68. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  72. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  73. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  74. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  75. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  77. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  78. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  79. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  80. package/lib/commonjs/jest-mock.js +2 -0
  81. package/lib/commonjs/jest-mock.js.map +1 -1
  82. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +156 -12
  83. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  84. package/lib/commonjs/version.json +1 -1
  85. package/lib/module/NativeRNWhisper.js.map +1 -1
  86. package/lib/module/jest-mock.js +2 -0
  87. package/lib/module/jest-mock.js.map +1 -1
  88. package/lib/module/realtime-transcription/RealtimeTranscriber.js +155 -12
  89. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  90. package/lib/module/version.json +1 -1
  91. package/lib/typescript/NativeRNWhisper.d.ts +1 -0
  92. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  93. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +29 -0
  94. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  95. package/lib/typescript/realtime-transcription/types.d.ts +7 -0
  96. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  97. package/package.json +1 -1
  98. package/src/NativeRNWhisper.ts +1 -0
  99. package/src/jest-mock.ts +2 -0
  100. package/src/realtime-transcription/RealtimeTranscriber.ts +179 -9
  101. package/src/realtime-transcription/types.ts +9 -0
  102. package/src/version.json +1 -1
@@ -124,6 +124,58 @@ void wsp_ggml_wsp_quantize_mat_q8_0_4x8_generic(const float * WSP_GGML_RESTRICT
124
124
  }
125
125
  }
126
126
 
127
+
128
void wsp_ggml_wsp_quantize_mat_q8_K_4x4_generic(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k) {
    // Quantize 4 rows of k floats (x) into the interleaved q8_K "4x4" layout (vy):
    // each row gets its own per-super-block scale, and the int8 quants of the
    // 4 rows are woven together in runs of blck_size_interleave bytes.
    assert(QK_K == 256);
    assert(k % QK_K == 0);
    const int nb = k / QK_K;  // super blocks per row

    block_q8_Kx4 * WSP_GGML_RESTRICT y = (block_q8_Kx4 *) vy;

    // scalar
    const int blck_size_interleave = 4;
    float srcv[4][QK_K];   // staging copy of the current super block of each of the 4 rows
    float iscale[4];       // per-row inverse quantization scale for this super block

    for (int i = 0; i < nb; i++) {
        for (int row_iter = 0; row_iter < 4; row_iter++) {
            float amax = 0.0f; // absolute max
            float max = 0;     // signed value at the absolute max (sign is preserved below)

            for (int j = 0; j < QK_K; j++) {
                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
                // Update the maximum value of the corresponding super block
                if(amax < fabsf(srcv[row_iter][j])) {
                    amax = fabsf(srcv[row_iter][j]);
                    max = srcv[row_iter][j];
                }
            }

            // -127/max maps the extreme value onto the int8 range while keeping its sign
            iscale[row_iter] = amax ? -127.f/max : 0;

            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
        }

        for (int j = 0; j < QK_K / 4; j++) {
            y[i].bsums[j] = 0;
        }

        // Quants values are interleaved in sequence of four bytes from corresponding super blocks
        // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
        // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
        for (int j = 0; j < QK_K * 4; j++) {
            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
            src_offset += (j % blck_size_interleave);
            // index picks the interleaved bsums slot this quant contributes to
            int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);

            float x0 = srcv[src_id][src_offset] * iscale[src_id];
            y[i].qs[j] = nearest_int(x0);
            y[i].bsums[index] += y[i].qs[j];
        }
    }
}
178
+
127
179
  void wsp_ggml_wsp_quantize_mat_q8_K_4x8_generic(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k) {
128
180
  assert(QK_K == 256);
129
181
  assert(k % QK_K == 0);
@@ -192,6 +244,12 @@ template <> void wsp_ggml_wsp_quantize_mat_t<8, WSP_GGML_TYPE_Q8_0>(const float
192
244
  wsp_ggml_wsp_quantize_mat_q8_0_4x8(x, vy, n_per_row);
193
245
  }
194
246
 
247
// Dispatcher specialization: quantize 4 float rows into the q8_K 4x4
// interleaved layout. nrow is fixed to 4 by the template argument.
template <> void wsp_ggml_wsp_quantize_mat_t<4, WSP_GGML_TYPE_Q8_K>(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
    assert(nrow == 4);
    UNUSED(nrow);
    wsp_ggml_wsp_quantize_mat_q8_K_4x4(x, vy, n_per_row);
}
252
+
195
253
  template <> void wsp_ggml_wsp_quantize_mat_t<8, WSP_GGML_TYPE_Q8_K>(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
196
254
  assert(nrow == 4);
197
255
  UNUSED(nrow);
@@ -333,6 +391,77 @@ void wsp_ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, siz
333
391
  }
334
392
  }
335
393
 
394
void wsp_ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
    // Reference (non-SIMD) GEMV: one q8_K activation row (vy) against panels of
    // 8-row-interleaved q4_K weights (vx, block_q4_Kx8), 4-byte interleave.
    // Writes nc float results to s.
    const int qk = QK_K;
    const int nb = n / qk;               // super blocks per row
    const int ncols_interleaved = 8;     // weight rows per interleaved panel
    const int blocklen = 4;              // bytes interleaved per row
    // Masks used to unpack the 6-bit packed scales/mins of q4_K
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(bs);
    UNUSED(nr);

    float sumf[8];        // running dot products per output column
    float sum_minf[8];    // running min-correction per output column
    uint32_t utmp[32];    // unpacked scales (low half) and mins (high half), 4 words per sub-block
    int sumi1;
    int sumi2;
    int sumi;

    const block_q8_K * a_ptr = (const block_q8_K *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0;
            sum_minf[j] = 0.0;
        }
        for (int l = 0; l < nb; l++) {
            // Expand the 12-byte packed 6-bit scales/mins of each of the 8 sub-blocks
            // into utmp: words 0-1 become scales, words 2-3 become mins.
            for (int sb = 0; sb < 8; sb++) {
                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
                utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
                utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
                utmp[sb * 4 + 2] = uaux_0;
                utmp[sb * 4 + 0] &= kmask1;
            }
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                // scales_0 applies to the low nibbles, scales_1 to the high nibbles
                uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
                uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi1 = 0;
                    sumi2 = 0;
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        // Each byte carries two 4-bit weights: low nibble pairs with the
                        // first 32 activations of the sub-block, high nibble with the next 32.
                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
                        sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);
                        sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);
                        sumi1 = sumi1 * scales_0[j];
                        sumi2 = sumi2 * scales_1[j];
                        sumi += sumi1 + sumi2;
                    }
                    sumf[j] += sumi * WSP_GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
                }
            }
            // Subtract the q4_K per-sub-block mins using the activation block sums
            for (int sb = 0; sb < 8; sb++) {
                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * WSP_GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
        }
    }
}
464
+
336
465
  void wsp_ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
337
466
  const int qk = QK_K;
338
467
  const int nb = n / qk;
@@ -563,6 +692,100 @@ void wsp_ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, s
563
692
  }
564
693
  }
565
694
 
695
+ void wsp_ggml_gemv_q8_0_4x4_q8_0_generic(int n,
696
+ float * WSP_GGML_RESTRICT s,
697
+ size_t bs,
698
+ const void * WSP_GGML_RESTRICT vx,
699
+ const void * WSP_GGML_RESTRICT vy,
700
+ int nr,
701
+ int nc) {
702
+ const int qk = QK8_0;
703
+ const int nb = n / qk;
704
+ const int ncols_interleaved = 4;
705
+ const int blocklen = 4;
706
+
707
+ assert(nr == 1);
708
+ assert(n % qk == 0);
709
+ assert(nc % ncols_interleaved == 0);
710
+
711
+ UNUSED(bs);
712
+ UNUSED(nr);
713
+
714
+ float sumf[4];
715
+ int sumi;
716
+
717
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
718
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
719
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
720
+
721
+ for (int j = 0; j < ncols_interleaved; j++) {
722
+ sumf[j] = 0.0;
723
+ }
724
+ for (int l = 0; l < nb; l++) {
725
+ for (int k = 0; k < (qk / blocklen); k++) {
726
+ for (int j = 0; j < ncols_interleaved; j++) {
727
+ sumi = 0;
728
+ for (int i = 0; i < blocklen; ++i) {
729
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
730
+ sumi += v0 * a_ptr[l].qs[k * blocklen + i];
731
+ }
732
+ sumf[j] += sumi * WSP_GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * WSP_GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
733
+ }
734
+ }
735
+ }
736
+ for (int j = 0; j < ncols_interleaved; j++) {
737
+ s[x * ncols_interleaved + j] = sumf[j];
738
+ }
739
+ }
740
+ }
741
+
742
+ void wsp_ggml_gemv_q8_0_4x8_q8_0_generic(int n,
743
+ float * WSP_GGML_RESTRICT s,
744
+ size_t bs,
745
+ const void * WSP_GGML_RESTRICT vx,
746
+ const void * WSP_GGML_RESTRICT vy,
747
+ int nr,
748
+ int nc) {
749
+ const int qk = QK8_0;
750
+ const int nb = n / qk;
751
+ const int ncols_interleaved = 4;
752
+ const int blocklen = 8;
753
+
754
+ assert(nr == 1);
755
+ assert(n % qk == 0);
756
+ assert(nc % ncols_interleaved == 0);
757
+
758
+ UNUSED(bs);
759
+ UNUSED(nr);
760
+
761
+ float sumf[4];
762
+ int sumi;
763
+
764
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
765
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
766
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
767
+
768
+ for (int j = 0; j < ncols_interleaved; j++) {
769
+ sumf[j] = 0.0;
770
+ }
771
+ for (int l = 0; l < nb; l++) {
772
+ for (int k = 0; k < (qk / blocklen); k++) {
773
+ for (int j = 0; j < ncols_interleaved; j++) {
774
+ sumi = 0;
775
+ for (int i = 0; i < blocklen; ++i) {
776
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
777
+ sumi += v0 * a_ptr[l].qs[k * blocklen + i];
778
+ }
779
+ sumf[j] += sumi * WSP_GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * WSP_GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
780
+ }
781
+ }
782
+ }
783
+ for (int j = 0; j < ncols_interleaved; j++) {
784
+ s[x * ncols_interleaved + j] = sumf[j];
785
+ }
786
+ }
787
+ }
788
+
566
789
  void wsp_ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
567
790
  const int qk = QK8_0;
568
791
  const int nb = n / qk;
@@ -727,6 +950,89 @@ void wsp_ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, siz
727
950
  }
728
951
  }
729
952
 
953
void wsp_ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
    // Reference (non-SIMD) GEMM: 4-row tiles of q8_Kx4 activations (vy) against
    // panels of 8-row-interleaved q4_K weights (vx, block_q4_Kx8, 4-byte interleave).
    // Writes a 4x8 float tile per (y, x) pair into s with row stride bs.
    const int qk = QK_K;
    const int nb = n / qk;               // super blocks per row
    const int ncols_interleaved = 8;     // weight rows per interleaved panel
    const int blocklen = 4;              // bytes interleaved per row
    // Masks used to unpack the 6-bit packed scales/mins of q4_K
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    float sumf[4][8];       // dot-product accumulators: [activation row][weight column]
    float sum_minf[4][8];   // min-correction accumulators
    uint32_t utmp[32];      // unpacked scales (low half) and mins (high half)
    int sumi1;
    int sumi2;
    int sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0;
                    sum_minf[m][j] = 0.0;
                }
            }
            for (int l = 0; l < nb; l++) {
                // Expand the 12-byte packed 6-bit scales/mins of each sub-block into utmp
                for (int sb = 0; sb < 8; sb++) {
                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
                    utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
                    utmp[sb * 4 + 2] = uaux_0;
                    utmp[sb * 4 + 0] &= kmask1;
                }
                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                    // scales_0 applies to the low nibbles, scales_1 to the high nibbles
                    uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
                    uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi1 = 0;
                            sumi2 = 0;
                            sumi = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                // Low nibble pairs with the first half of the sub-block's
                                // interleaved activations, high nibble with the second half.
                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
                                sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);
                                sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);
                                sumi1 = sumi1 * scales_0[j];
                                sumi2 = sumi2 * scales_1[j];
                                sumi += sumi1 + sumi2;
                            }
                            sumf[m][j] += sumi * WSP_GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
                        }
                    }
                }
                // Subtract per-sub-block mins using the interleaved activation block sums
                for (int sb = 0; sb < 8; sb++) {
                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
                    for(int m = 0; m < 4; m++) {
                        // NOTE(review): the "- ((sb % 2) * 6)" term compensates for the
                        // q8_Kx4 bsums interleaving order — verify against
                        // wsp_ggml_wsp_quantize_mat_q8_K_4x4's bsums layout.
                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
                        for(int j = 0; j < ncols_interleaved; j++) {
                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * WSP_GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
                }
            }
        }
    }
}
1035
+
730
1036
  void wsp_ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
731
1037
  const int qk = QK_K;
732
1038
  const int nb = n / qk;
@@ -1007,8 +1313,129 @@ void wsp_ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, s
1007
1313
  }
1008
1314
  }
1009
1315
 
1316
+ void wsp_ggml_gemm_q8_0_4x4_q8_0_generic(int n,
1317
+ float * WSP_GGML_RESTRICT s,
1318
+ size_t bs,
1319
+ const void * WSP_GGML_RESTRICT vx,
1320
+ const void * WSP_GGML_RESTRICT vy,
1321
+ int nr,
1322
+ int nc) {
1323
+ const int qk = QK8_0;
1324
+ const int nb = n / qk;
1325
+ const int ncols_interleaved = 4;
1326
+ const int blocklen = 4;
1327
+
1328
+ assert(n % qk == 0);
1329
+ assert(nr % 4 == 0);
1330
+ assert(nc % ncols_interleaved == 0);
1331
+
1332
+ float sumf[4][4];
1333
+ int sumi;
1334
+
1335
+ for (int y = 0; y < nr / 4; y++) {
1336
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1337
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1338
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
1339
+ for (int m = 0; m < 4; m++) {
1340
+ for (int j = 0; j < ncols_interleaved; j++) {
1341
+ sumf[m][j] = 0.0;
1342
+ }
1343
+ }
1344
+ for (int l = 0; l < nb; l++) {
1345
+ for (int k = 0; k < (qk / blocklen); k++) {
1346
+ for (int m = 0; m < 4; m++) {
1347
+ for (int j = 0; j < ncols_interleaved; j++) {
1348
+ sumi = 0;
1349
+ for (int i = 0; i < blocklen; ++i) {
1350
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
1351
+ sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
1352
+ }
1353
+ sumf[m][j] +=
1354
+ sumi * WSP_GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * WSP_GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
1355
+ }
1356
+ }
1357
+ }
1358
+ }
1359
+ for (int m = 0; m < 4; m++) {
1360
+ for (int j = 0; j < ncols_interleaved; j++) {
1361
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1362
+ }
1363
+ }
1364
+ }
1365
+ }
1366
+ }
1367
+
1368
+ void wsp_ggml_gemm_q8_0_4x8_q8_0_generic(int n,
1369
+ float * WSP_GGML_RESTRICT s,
1370
+ size_t bs,
1371
+ const void * WSP_GGML_RESTRICT vx,
1372
+ const void * WSP_GGML_RESTRICT vy,
1373
+ int nr,
1374
+ int nc) {
1375
+ const int qk = QK8_0;
1376
+ const int nb = n / qk;
1377
+ const int ncols_interleaved = 4;
1378
+ const int blocklen = 8;
1379
+
1380
+ assert(n % qk == 0);
1381
+ assert(nr % 4 == 0);
1382
+ assert(nc % ncols_interleaved == 0);
1383
+
1384
+ float sumf[4][4];
1385
+ int sumi;
1386
+
1387
+ for (int y = 0; y < nr / 4; y++) {
1388
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1389
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1390
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
1391
+ for (int m = 0; m < 4; m++) {
1392
+ for (int j = 0; j < ncols_interleaved; j++) {
1393
+ sumf[m][j] = 0.0;
1394
+ }
1395
+ }
1396
+ for (int l = 0; l < nb; l++) {
1397
+ for (int k = 0; k < (qk / blocklen); k++) {
1398
+ for (int m = 0; m < 4; m++) {
1399
+ for (int j = 0; j < ncols_interleaved; j++) {
1400
+ sumi = 0;
1401
+ for (int i = 0; i < blocklen; ++i) {
1402
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
1403
+ sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
1404
+ }
1405
+ sumf[m][j] +=
1406
+ sumi * WSP_GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * WSP_GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
1407
+ }
1408
+ }
1409
+ }
1410
+ }
1411
+ for (int m = 0; m < 4; m++) {
1412
+ for (int j = 0; j < ncols_interleaved; j++) {
1413
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1414
+ }
1415
+ }
1416
+ }
1417
+ }
1418
+ }
1419
+
1010
1420
  } // extern "C"
1011
1421
 
1422
+ static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) {
1423
+ block_q8_0x4 out;
1424
+
1425
+ for (int i = 0; i < 4; i++) {
1426
+ out.d[i] = in[i].d;
1427
+ }
1428
+
1429
+ const int end = QK8_0 * 4 / blck_size_interleave;
1430
+ for (int i = 0; i < end; ++i) {
1431
+ int src_id = i % 4;
1432
+ int src_offset = (i / 4) * blck_size_interleave;
1433
+ int dst_offset = i * blck_size_interleave;
1434
+ memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
1435
+ }
1436
+ return out;
1437
+ }
1438
+
1012
1439
  static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
1013
1440
  block_q4_0x4 out;
1014
1441
 
@@ -1228,9 +1655,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct wsp_ggml_tensor * t, int interleave_b
1228
1655
 
1229
1656
  WSP_GGML_UNUSED(data_size);
1230
1657
  }
1658
+
1231
1659
  static int repack_q4_K_to_q4_K_8_bl(struct wsp_ggml_tensor * t, int interleave_block, const void * WSP_GGML_RESTRICT data, size_t data_size) {
1232
1660
  WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_Q4_K);
1233
- WSP_GGML_ASSERT(interleave_block == 8);
1661
+ WSP_GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
1234
1662
  constexpr int nrows_interleaved = 8;
1235
1663
 
1236
1664
  block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
@@ -1321,6 +1749,38 @@ static int repack_q4_0_to_q4_0_8_bl(struct wsp_ggml_tensor * t, int interleave_b
1321
1749
  WSP_GGML_UNUSED(data_size);
1322
1750
  }
1323
1751
 
1752
+ static int repack_q8_0_to_q8_0_4_bl(struct wsp_ggml_tensor * t,
1753
+ int interleave_block,
1754
+ const void * WSP_GGML_RESTRICT data,
1755
+ size_t data_size) {
1756
+ WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_Q8_0);
1757
+ WSP_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
1758
+ constexpr int nrows_interleaved = 4;
1759
+
1760
+ block_q8_0x4 * dst = (block_q8_0x4 *) t->data;
1761
+ const block_q8_0 * src = (const block_q8_0 *) data;
1762
+ block_q8_0 dst_tmp[4];
1763
+ int nrow = wsp_ggml_nrows(t);
1764
+ int nblocks = t->ne[0] / QK8_0;
1765
+
1766
+ WSP_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));
1767
+
1768
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
1769
+ return -1;
1770
+ }
1771
+
1772
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
1773
+ for (int64_t x = 0; x < nblocks; x++) {
1774
+ for (int i = 0; i < nrows_interleaved; i++) {
1775
+ dst_tmp[i] = src[x + i * nblocks];
1776
+ }
1777
+ *dst++ = make_block_q8_0x4(dst_tmp, interleave_block);
1778
+ }
1779
+ src += nrows_interleaved * nblocks;
1780
+ }
1781
+ return 0;
1782
+ }
1783
+
1324
1784
  static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
1325
1785
  block_iq4_nlx4 out;
1326
1786
 
@@ -1468,6 +1928,10 @@ template <> int repack<block_q4_K, 8, 8>(struct wsp_ggml_tensor * t, const void
1468
1928
  return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
1469
1929
  }
1470
1930
 
1931
// Repack entry for the 4-byte-interleave q4_K variant; the shared repacker
// still emits the q4_Kx8 layout — only the interleave width differs.
template <> int repack<block_q4_K, 4, 8>(struct wsp_ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);
}
1934
+
1471
1935
  template <> int repack<block_q2_K, 8, 8>(struct wsp_ggml_tensor * t, const void * data, size_t data_size) {
1472
1936
  return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
1473
1937
  }
@@ -1485,6 +1949,14 @@ template <> int repack<block_iq4_nl, 8, 8>(struct wsp_ggml_tensor * t, const voi
1485
1949
  return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
1486
1950
  }
1487
1951
 
1952
// Repack entry: q8_0 weights into 4-row groups with 4-byte interleave.
template <> int repack<block_q8_0, 4, 4>(struct wsp_ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
}
1955
+
1956
// Repack entry: q8_0 weights into 4-row groups with 8-byte interleave.
template <> int repack<block_q8_0, 8, 4>(struct wsp_ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size);
}
1959
+
1488
1960
  // gemv
1489
1961
  template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type PARAM_TYPE>
1490
1962
  void gemv(int, float *, size_t, const void *, const void *, int, int);
@@ -1501,6 +1973,10 @@ template <> void gemv<block_q4_0, 8, 8, WSP_GGML_TYPE_Q8_0>(int n, float * s, si
1501
1973
  wsp_ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1502
1974
  }
1503
1975
 
1976
// GEMV dispatch: INTER_SIZE=4 / NB_COLS=8 maps to the "8x4" kernel
// (8 interleaved columns, 4-byte interleave).
template <> void gemv<block_q4_K, 4, 8, WSP_GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    wsp_ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
1979
+
1504
1980
  template <> void gemv<block_q4_K, 8, 8, WSP_GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1505
1981
  wsp_ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
1506
1982
  }
@@ -1517,6 +1993,14 @@ template <> void gemv<block_iq4_nl, 8, 8, WSP_GGML_TYPE_Q8_0>(int n, float * s,
1517
1993
  wsp_ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1518
1994
  }
1519
1995
 
1996
// GEMV dispatch: q8_0 weights, 4-row groups, 4-byte interleave.
template <> void gemv<block_q8_0, 4, 4, WSP_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    wsp_ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
1999
+
2000
// GEMV dispatch: q8_0 weights, 4-row groups, 8-byte interleave.
template <> void gemv<block_q8_0, 8, 4, WSP_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    wsp_ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
2003
+
1520
2004
  // gemm
1521
2005
  template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type PARAM_TYPE>
1522
2006
  void gemm(int, float *, size_t, const void *, const void *, int, int);
@@ -1529,6 +2013,10 @@ template <> void gemm<block_q4_0, 8, 4, WSP_GGML_TYPE_Q8_0>(int n, float * s, si
1529
2013
  wsp_ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
1530
2014
  }
1531
2015
 
2016
// GEMM dispatch: INTER_SIZE=4 / NB_COLS=8 maps to the "8x4" q4_K kernel
// (8 interleaved columns, 4-byte interleave).
template <> void gemm<block_q4_K, 4, 8, WSP_GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    wsp_ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
2019
+
1532
2020
  template <> void gemm<block_q4_0, 8, 8, WSP_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1533
2021
  wsp_ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1534
2022
  }
@@ -1549,6 +2037,14 @@ template <> void gemm<block_iq4_nl, 8, 8, WSP_GGML_TYPE_Q8_0>(int n, float * s,
1549
2037
  wsp_ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1550
2038
  }
1551
2039
 
2040
// GEMM dispatch: q8_0 weights, 4-row groups, 4-byte interleave.
template <> void gemm<block_q8_0, 4, 4, WSP_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    wsp_ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
2043
+
2044
// GEMM dispatch: q8_0 weights, 4-row groups, 8-byte interleave.
template <> void gemm<block_q8_0, 8, 4, WSP_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    wsp_ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
2047
+
1552
2048
  class tensor_traits_base : public ggml::cpu::tensor_traits {
1553
2049
  public:
1554
2050
  virtual int repack(struct wsp_ggml_tensor * t, const void * data, size_t data_size) = 0;
@@ -1731,12 +2227,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, wsp_ggml_type
1731
2227
  nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
1732
2228
  }
1733
2229
 
1734
- if (nth == 1 || nchunk0 < nth || disable_chunking) {
2230
+ int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
2231
+ // Only increase nchunk0 to nth if it won't make chunks too small
2232
+ if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
1735
2233
  nchunk0 = nth;
2234
+ dr0 = (nr0 + nchunk0 - 1) / nchunk0;
1736
2235
  }
1737
2236
 
1738
- const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
1739
-
1740
2237
  // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
1741
2238
  // This prevents creating too many tiny chunks that could overlap after alignment
1742
2239
  const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
@@ -1930,6 +2427,9 @@ static const ggml::cpu::tensor_traits * wsp_ggml_repack_get_optimal_repack_type(
1930
2427
  static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, WSP_GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
1931
2428
  static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, WSP_GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
1932
2429
  static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, WSP_GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
2430
+
2431
+ // instance for Q4_K
2432
+ static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, WSP_GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
1933
2433
  static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, WSP_GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
1934
2434
 
1935
2435
  // instance for Q2
@@ -1939,8 +2439,13 @@ static const ggml::cpu::tensor_traits * wsp_ggml_repack_get_optimal_repack_type(
1939
2439
  static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, WSP_GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
1940
2440
  static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, WSP_GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;
1941
2441
 
2442
+ // instance for Q8_0
2443
+ static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, WSP_GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
2444
+ static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, WSP_GGML_TYPE_Q8_0> q8_0_4x8_q8_0;
2445
+
1942
2446
  if (cur->type == WSP_GGML_TYPE_Q4_0) {
1943
- if (wsp_ggml_cpu_has_avx2() || (wsp_ggml_cpu_has_sve() && wsp_ggml_cpu_has_matmul_int8() && wsp_ggml_cpu_get_sve_cnt() == QK8_0)) {
2447
+ if (wsp_ggml_cpu_has_avx2() || (wsp_ggml_cpu_has_sve() && wsp_ggml_cpu_has_matmul_int8() && wsp_ggml_cpu_get_sve_cnt() == QK8_0)
2448
+ || (wsp_ggml_cpu_has_riscv_v() && (wsp_ggml_cpu_get_rvv_vlen() >= QK4_0))) {
1944
2449
  if (cur->ne[1] % 8 == 0) {
1945
2450
  return &q4_0_8x8_q8_0;
1946
2451
  }
@@ -1961,6 +2466,16 @@ static const ggml::cpu::tensor_traits * wsp_ggml_repack_get_optimal_repack_type(
1961
2466
  return &q4_K_8x8_q8_K;
1962
2467
  }
1963
2468
  }
2469
+ if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_matmul_int8()) {
2470
+ if (cur->ne[1] % 8 == 0) {
2471
+ return &q4_K_8x8_q8_K;
2472
+ }
2473
+ }
2474
+ if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_dotprod()) {
2475
+ if (cur->ne[1] % 8 == 0) {
2476
+ return &q4_K_8x4_q8_K;
2477
+ }
2478
+ }
1964
2479
  } else if (cur->type == WSP_GGML_TYPE_Q2_K) {
1965
2480
  if (wsp_ggml_cpu_has_avx512()) {
1966
2481
  if (cur->ne[1] % 8 == 0) {
@@ -1978,6 +2493,17 @@ static const ggml::cpu::tensor_traits * wsp_ggml_repack_get_optimal_repack_type(
1978
2493
  return &iq4_nl_4x4_q8_0;
1979
2494
  }
1980
2495
  }
2496
+ } else if (cur->type == WSP_GGML_TYPE_Q8_0) {
2497
+ if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_matmul_int8()) {
2498
+ if (cur->ne[1] % 4 == 0) {
2499
+ return &q8_0_4x8_q8_0;
2500
+ }
2501
+ }
2502
+ if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_dotprod()) {
2503
+ if (cur->ne[1] % 4 == 0) {
2504
+ return &q8_0_4x4_q8_0;
2505
+ }
2506
+ }
1981
2507
  }
1982
2508
 
1983
2509
  return nullptr;