@fugood/llama.node 1.3.7 → 1.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. package/lib/binding.js +18 -1
  2. package/lib/binding.ts +19 -1
  3. package/lib/index.js +3 -3
  4. package/lib/index.ts +1 -1
  5. package/package.json +15 -15
  6. package/scripts/llama.cpp.patch +7 -7
  7. package/src/LlamaCompletionWorker.cpp +2 -2
  8. package/src/llama.cpp/common/arg.cpp +27 -2
  9. package/src/llama.cpp/common/chat-parser.cpp +968 -0
  10. package/src/llama.cpp/common/chat.cpp +0 -952
  11. package/src/llama.cpp/common/common.cpp +55 -0
  12. package/src/llama.cpp/common/common.h +18 -0
  13. package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -2
  14. package/src/llama.cpp/ggml/CMakeLists.txt +6 -4
  15. package/src/llama.cpp/ggml/include/ggml-rpc.h +1 -1
  16. package/src/llama.cpp/ggml/include/ggml.h +12 -4
  17. package/src/llama.cpp/ggml/src/CMakeLists.txt +26 -4
  18. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +29 -15
  19. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +721 -0
  20. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  21. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +22 -2
  22. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +9 -0
  23. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +71 -4
  24. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  25. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +243 -4
  26. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +6 -0
  27. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +84 -85
  28. package/src/llama.cpp/include/llama.h +18 -0
  29. package/src/llama.cpp/src/CMakeLists.txt +2 -0
  30. package/src/llama.cpp/src/llama-arch.cpp +95 -16
  31. package/src/llama.cpp/src/llama-arch.h +15 -0
  32. package/src/llama.cpp/src/llama-context.cpp +7 -3
  33. package/src/llama.cpp/src/llama-graph.cpp +3 -3
  34. package/src/llama.cpp/src/llama-hparams.h +1 -1
  35. package/src/llama.cpp/src/llama-model.cpp +141 -6
  36. package/src/llama.cpp/src/llama-model.h +4 -0
  37. package/src/llama.cpp/src/llama-quant.cpp +13 -5
  38. package/src/llama.cpp/src/models/lfm2.cpp +5 -3
  39. package/src/llama.cpp/src/models/models.h +55 -1
  40. package/src/llama.cpp/src/models/qwen3next.cpp +1042 -0
  41. package/src/llama.cpp/src/models/rnd1.cpp +126 -0
--- /dev/null
+++ b/package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp
@@ -0,0 +1,38 @@
+#include "ggml-backend-impl.h"
+
+#if defined(__riscv) && __riscv_xlen == 64
+#include <asm/hwprobe.h>
+#include <asm/unistd.h>
+#include <unistd.h>
+
+struct riscv64_features {
+    bool has_rvv = false;
+
+    riscv64_features() {
+        struct riscv_hwprobe probe;
+        probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;
+        probe.value = 0;
+
+        int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);
+
+        if (0 == ret) {
+            has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);
+        }
+    }
+};
+
+static int ggml_backend_cpu_riscv64_score() {
+    int score = 1;
+    riscv64_features rf;
+
+#ifdef GGML_USE_RVV
+    if (!rf.has_rvv) { return 0; }
+    score += 1 << 1;
+#endif
+
+    return score;
+}
+
+GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_riscv64_score)
+
+#endif // __riscv && __riscv_xlen == 64
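
The new file registers a score function with ggml's dynamic backend loader: a CPU-backend variant reports 0 when a required ISA feature is missing and a higher score when it is present, and the loader picks the highest-scoring build. On riscv64 the probe goes through the Linux riscv_hwprobe syscall. A standalone sketch of the same query, assuming a riscv64 Linux kernel new enough (6.4+) to expose hwprobe:

    // Standalone sketch: query the V (vector) extension the same way
    // the new cpu-feats.cpp does. Assumes Linux >= 6.4 on riscv64.
    #include <asm/hwprobe.h>   // struct riscv_hwprobe, RISCV_HWPROBE_*
    #include <asm/unistd.h>    // __NR_riscv_hwprobe
    #include <unistd.h>        // syscall()
    #include <cstdio>

    int main() {
        struct riscv_hwprobe probe;
        probe.key   = RISCV_HWPROBE_KEY_IMA_EXT_0;
        probe.value = 0;
        if (syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0) == 0) {
            std::printf("RVV: %s\n", (probe.value & RISCV_HWPROBE_IMA_V) ? "yes" : "no");
        }
        return 0;
    }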
--- a/package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h
+++ b/package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h
@@ -33,10 +33,12 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -44,27 +46,30 @@
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 // repack.cpp
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
-#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
-#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #elif defined(__POWERPC__) || defined(__powerpc__)
 // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
@@ -76,10 +81,12 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -87,6 +94,7 @@
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -101,10 +109,12 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -112,6 +122,7 @@
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -134,15 +145,18 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -163,10 +177,12 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -174,6 +190,7 @@
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@@ -196,10 +213,12 @@
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@@ -207,6 +226,7 @@
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
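
All of the churn above follows one convention: arch-fallback.h maps each _generic scalar kernel name onto the public kernel name for every architecture that lacks a hand-written implementation, so the generic definitions in repack.cpp compile directly under the public symbols. A toy illustration of the mechanism, with a hypothetical kernel name foo (not from the package):

    // Toy version of the arch-fallback aliasing pattern (hypothetical
    // kernel "foo"; the real names are the ggml_gemv_*/ggml_gemm_* above).
    #if !defined(MYARCH_HAS_SIMD_FOO)
    // The preprocessor renames the _generic definition below, so the
    // scalar body *becomes* the public symbol foo().
    #define foo_generic foo
    #endif

    // In the "repack.cpp" equivalent:
    void foo_generic(const float * x, float * y, int n) {
        for (int i = 0; i < n; i++) {
            y[i] = x[i] * 2.0f; // stand-in scalar body
        }
    }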
--- a/package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_argsort(params, tensor);
             } break;
+        case GGML_OP_TOP_K:
+            {
+                ggml_compute_forward_top_k(params, tensor);
+            } break;
         case GGML_OP_LEAKY_RELU:
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
@@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
+        case GGML_OP_TOP_K:
         case GGML_OP_FLASH_ATTN_EXT:
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
@@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
                     cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                 } break;
+            case GGML_OP_TOP_K:
+                {
+                    cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
+                } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne10 = node->src[1]->ne[0]; // DK
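
The planner entry reserves one int32 index row per task for GGML_OP_TOP_K, since every thread re-sorts a full-width row of indices before writing out the top slice (the kernel additionally staggers each thread's slot by CACHE_LINE_SIZE_F32 to reduce false sharing). A rough worked example of the reservation, under assumed sizes:

    // Worked example of the GGML_OP_TOP_K scratch reservation above,
    // with assumed values: 32000-wide logit rows, 8 tasks.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne00    = 32000; // node->src[0]->ne[0]
        const int     n_tasks = 8;
        const size_t  cur     = sizeof(int32_t) * ne00 * n_tasks;
        std::printf("%zu bytes (~%.2f MiB)\n", cur, cur / (1024.0 * 1024.0));
        // -> 1024000 bytes, about 0.98 MiB of cplan work buffer
        return 0;
    }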
--- a/package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp
+++ b/package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp
@@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
 // ggml_compute_forward_argsort
 
 template<enum ggml_sort_order order>
-struct argsort_cmp {
+struct cmp_argsort {
     const float * data;
     bool operator()(int32_t a, int32_t b) const {
         if constexpr (order == GGML_SORT_ORDER_ASC) {
@@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(
 
     switch (order) {
         case GGML_SORT_ORDER_ASC:
-            std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+            std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
             break;
 
         case GGML_SORT_ORDER_DESC:
-            std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+            std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
             break;
 
         default:
@@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
     }
 }
 
+// ggml_compute_forward_top_k
+
+struct cmp_top_k {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        return data[a] > data[b];
+    }
+};
+
+static void ggml_compute_forward_top_k_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int top_k = ne0;
+
+    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne00; j++) {
+            tmp[j] = j;
+        }
+
+        std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
+
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
+        std::copy(tmp, tmp + top_k, dst_data);
+
+        // emphasize that the order is not important
+        if (top_k > 1) {
+            std::swap(dst_data[0], dst_data[1]);
+        }
+    }
+}
+
+void ggml_compute_forward_top_k(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_top_k_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_flash_attn_ext
 
 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
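
The new op reuses the argsort machinery but sorts only the first k positions: std::partial_sort moves the k largest-scoring indices to the front in roughly O(n log k), and the deliberate swap of the first two outputs signals that callers must not rely on any ordering within the result. A self-contained sketch of the same idea (assumed data, no ggml types):

    // Standalone illustration of the top-k index selection used by
    // ggml_compute_forward_top_k_f32.
    #include <algorithm>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
        const std::vector<float> logits = {0.1f, 2.5f, -1.0f, 0.7f, 3.2f, 0.0f};
        const int k = 3;

        std::vector<int32_t> idx(logits.size());
        std::iota(idx.begin(), idx.end(), 0);          // 0,1,2,...,n-1

        // Move the k indices with the largest values to the front.
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [&](int32_t a, int32_t b) { return logits[a] > logits[b]; });

        for (int i = 0; i < k; ++i) {
            std::printf("top[%d] = index %d (%.1f)\n", i, idx[i], logits[idx[i]]);
        }
        // prints indices 4, 1, 3 (values 3.2, 2.5, 0.7)
        return 0;
    }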
@@ -9700,7 +9766,8 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
             }
 
             const float diag = A_batch[i00 * n + i00];
-            GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
+            assert(diag != 0.0f && "Zero diagonal in triangular matrix");
+
             X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
         }
     }
--- a/package/src/llama.cpp/ggml/src/ggml-cpu/ops.h
+++ b/package/src/llama.cpp/ggml/src/ggml-cpu/ops.h
@@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
--- a/package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp
+++ b/package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp
@@ -124,6 +124,58 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG
     }
 }
 
+
+void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK_K == 256);
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
+
+    // scalar
+    const int blck_size_interleave = 4;
+    float srcv[4][QK_K];
+    float iscale[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+            float max = 0;
+
+            for (int j = 0; j < QK_K; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
+                // Update the maximum value of the corresponding super block
+                if(amax < fabsf(srcv[row_iter][j])) {
+                    amax = fabsf(srcv[row_iter][j]);
+                    max = srcv[row_iter][j];
+                }
+            }
+
+            iscale[row_iter] = amax ? -127.f/max : 0;
+
+            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
+        }
+
+        for (int j = 0; j < QK_K / 4; j++) {
+            y[i].bsums[j] = 0;
+        }
+
+        // Quants values are interleaved in sequence of four bytes from corresponding super blocks
+        // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
+        // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
+        for (int j = 0; j < QK_K * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+            int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
+
+            float x0 = srcv[src_id][src_offset] * iscale[src_id];
+            y[i].qs[j] = nearest_int(x0);
+            y[i].bsums[index] += y[i].qs[j];
+        }
+    }
+}
+
 void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
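
One detail of the scale selection above is easy to miss: iscale is computed from the signed max (the entry whose magnitude equals amax), so the most extreme value always rounds exactly to -127 and its sign is carried by the stored reciprocal d. A small worked sketch under assumed inputs (nearest_int below stands in for ggml's rounding helper):

    // Worked example of the Q8_K row-scale trick from the hunk above.
    #include <cmath>
    #include <cstdio>

    static int nearest_int(float v) { return (int) std::lroundf(v); }

    int main() {
        const float row[4] = {0.5f, -2.0f, 1.0f, 0.25f};
        float amax = 0.0f, max = 0.0f;
        for (float v : row) {
            if (std::fabs(v) > amax) { amax = std::fabs(v); max = v; }
        }
        const float iscale = amax ? -127.0f / max : 0.0f; // -127 / -2.0 = 63.5
        const float d      = amax ? 1.0f / iscale : 0.0f; // dequantization scale
        for (float v : row) {
            std::printf("%5.2f -> %4d\n", v, nearest_int(v * iscale));
        }
        // -2.00 -> -127, and dequantizing -127 * d recovers -2.0 exactly
        return 0;
    }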
@@ -192,6 +244,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTR
     ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
 }
 
+template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
+}
+
 template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
     assert(nrow == 4);
     UNUSED(nrow);
@@ -333,6 +391,77 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 4;
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[8];
+    float sum_minf[8];
+    uint32_t utmp[32];
+    int sumi1;
+    int sumi2;
+    int sumi;
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+            sum_minf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int sb = 0; sb < 8; sb++) {
+                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
+                utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                utmp[sb * 4 + 2] = uaux_0;
+                utmp[sb * 4 + 0] &= kmask1;
+            }
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
+                uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi1 = 0;
+                    sumi2 = 0;
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
+                        sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);
+                        sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);
+                        sumi1 = sumi1 * scales_0[j];
+                        sumi2 = sumi2 * scales_1[j];
+                        sumi += sumi1 + sumi2;
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                }
+            }
+            for (int sb = 0; sb < 8; sb++) {
+                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
+        }
+    }
+}
+
 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
@@ -727,6 +856,89 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 4;
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+    float sumf[4][8];
+    float sum_minf[4][8];
+    uint32_t utmp[32];
+    int sumi1;
+    int sumi2;
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0;
+                    sum_minf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int sb = 0; sb < 8; sb++) {
+                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
+                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                    utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                    utmp[sb * 4 + 2] = uaux_0;
+                    utmp[sb * 4 + 0] &= kmask1;
+                }
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
+                    uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi1 = 0;
+                            sumi2 = 0;
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
+                                sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);
+                                sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);
+                                sumi1 = sumi1 * scales_0[j];
+                                sumi2 = sumi2 * scales_1[j];
+                                sumi += sumi1 + sumi2;
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+                for (int sb = 0; sb < 8; sb++) {
+                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+                    for(int m = 0; m < 4; m++) {
+                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
+                        for(int j = 0; j < ncols_interleaved; j++) {
+                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
+                }
+            }
+        }
+    }
+}
+
 void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
@@ -1228,9 +1440,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
 
     GGML_UNUSED(data_size);
 }
+
 static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
-    GGML_ASSERT(interleave_block == 8);
+    GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
     constexpr int nrows_interleaved = 8;
 
     block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
@@ -1468,6 +1681,10 @@ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * da
     return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_q4_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);
+}
+
 template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
 }
@@ -1501,6 +1718,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
     ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
@@ -1529,6 +1750,10 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
     ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
@@ -1731,12 +1956,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
     }
 
-    if (nth == 1 || nchunk0 < nth || disable_chunking) {
+    int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+    // Only increase nchunk0 to nth if it won't make chunks too small
+    if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
         nchunk0 = nth;
+        dr0 = (nr0 + nchunk0 - 1) / nchunk0;
     }
 
-    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
-
     // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
     // This prevents creating too many tiny chunks that could overlap after alignment
    const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
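
The reworked guard addresses a sizing corner case: previously nchunk0 was bumped to nth whenever there were fewer chunks than threads, which on high thread counts could shrink dr0 (rows per chunk) below min_chunk_size. Now the bump only happens if one chunk per thread still yields chunks of at least the minimum size. A sketch of the guarded computation with assumed numbers:

    // Sketch of the guarded chunk sizing above (assumed values:
    // 64 rows, 16 threads, 16-row minimum chunks).
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t nr0 = 64, min_chunk_size = 16;
        const int     nth = 16;
        const bool    disable_chunking = false;

        int64_t nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size; // 4 chunks
        int64_t dr0     = (nr0 + nchunk0 - 1) / nchunk0;               // 16 rows each

        // Only widen to one chunk per thread if chunks stay >= min_chunk_size.
        if (nth == 1 || ((nchunk0 < nth || disable_chunking) &&
                         (nr0 + nth - 1) / nth >= min_chunk_size)) {
            nchunk0 = nth;
            dr0     = (nr0 + nchunk0 - 1) / nchunk0;
        }
        std::printf("nchunk0=%lld dr0=%lld\n", (long long) nchunk0, (long long) dr0);
        // Old behaviour: nchunk0=16, dr0=4 (below the minimum).
        // Guarded: nchunk0 stays 4, dr0 stays 16.
        return 0;
    }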
@@ -1930,6 +2156,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
     static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
+
+    // instance for Q4_K
+    static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
     static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
 
     // instance for Q2
@@ -1961,6 +2190,16 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q4_K_8x8_q8_K;
             }
         }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q4_K_8x8_q8_K;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q4_K_8x4_q8_K;
+            }
+        }
    } else if (cur->type == GGML_TYPE_Q2_K) {
        if (ggml_cpu_has_avx512()) {
            if (cur->ne[1] % 8 == 0) {
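
Taken together with the new q4_K_8x4_q8_K traits instance above, the selection now degrades gracefully on AArch64: i8mm-capable cores keep the 8x8 Q4_K kernel, while dotprod-only cores fall back to the new 8x4 path instead of getting no repack at all. A condensed sketch of that ordering (simplified; the real function returns tensor_traits pointers and also handles other ISAs and quant types):

    // Condensed sketch of the new Q4_K kernel selection on AArch64
    // (simplified from ggml_repack_get_optimal_repack_type above).
    #include <cstdint>

    const char * pick_q4_K_kernel(bool has_i8mm, bool has_dotprod, int64_t ne1) {
        if (ne1 % 8 != 0) return "no repack";       // need 8-row interleave
        if (has_i8mm)     return "q4_K_8x8_q8_K";   // NEON + int8 matmul
        if (has_dotprod)  return "q4_K_8x4_q8_K";   // NEON + dotprod only
        return "no repack";
    }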