llama_cpp 0.14.5 → 0.14.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,12 @@
  #include <string.h> // memcpy
  #include <math.h> // fabsf

+ #undef MIN
+ #undef MAX
+
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
+
  #ifdef __cplusplus
  extern "C" {
  #endif
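Note on the MIN/MAX macros added above: a function-like macro evaluates each argument once per occurrence in the expansion, so an argument with side effects runs twice. A one-line illustration (not from the package):

  int i = 0;
  int m = MIN(i++, 10); // expands to ((i++) < (10) ? (i++) : (10)); m == 1 and i ends up at 2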
@@ -45,7 +51,7 @@ extern "C" {
  // 16-bit float
  // on Arm, we use __fp16
  // on x86, we use uint16_t
- #if defined(__ARM_NEON) && !defined(_MSC_VER)
+ #if defined(__ARM_NEON)

  // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
  //
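Dropping !defined(_MSC_VER) from this guard lets MSVC arm64 builds take the NEON path too. The hunk below adds the pieces that make that work: a uint16_t-backed ggml_fp16_internal_t (MSVC's arm_neon.h has no __fp16) and a ggml_vld1q_u32 initializer that, on MSVC, appears to pack the four 32-bit lanes into two 64-bit halves to match that compiler's vector-literal layout.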
@@ -53,8 +59,262 @@ extern "C" {
  //
  #include <arm_neon.h>

+ #ifdef _MSC_VER
+
+ typedef uint16_t ggml_fp16_internal_t;
+
+ #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
+
+ #else
+
  typedef __fp16 ggml_fp16_internal_t;

+ #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
+
+ #endif // _MSC_VER
+
+ #if !defined(__aarch64__)
+
+ // 32-bit ARM compatibility
+
+ // vaddvq_s16
+ // vpaddq_s16
+ // vpaddq_s32
+ // vaddvq_s32
+ // vaddvq_f32
+ // vmaxvq_f32
+ // vcvtnq_s32_f32
+ // vzip1_u8
+ // vzip2_u8
+
+ inline static int32_t vaddvq_s16(int16x8_t v) {
+     return
+         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+         (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+         (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+         (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+ }
+
+ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+     int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+     int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+     return vcombine_s16(a0, b0);
+ }
+
+ inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
+     int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
+     int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
+     return vcombine_s32(a0, b0);
+ }
+
+ inline static int32_t vaddvq_s32(int32x4_t v) {
+     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+ }
+
+ inline static float vaddvq_f32(float32x4_t v) {
+     return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+ }
+
+ inline static float vmaxvq_f32(float32x4_t v) {
+     return
+         MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+             MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+ }
+
+ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+     int32x4_t res;
+
+     res[0] = roundf(vgetq_lane_f32(v, 0));
+     res[1] = roundf(vgetq_lane_f32(v, 1));
+     res[2] = roundf(vgetq_lane_f32(v, 2));
+     res[3] = roundf(vgetq_lane_f32(v, 3));
+
+     return res;
+ }
+
+ inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+     uint8x8_t res;
+
+     res[0] = a[0]; res[1] = b[0];
+     res[2] = a[1]; res[3] = b[1];
+     res[4] = a[2]; res[5] = b[2];
+     res[6] = a[3]; res[7] = b[3];
+
+     return res;
+ }
+
+ inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+     uint8x8_t res;
+
+     res[0] = a[4]; res[1] = b[4];
+     res[2] = a[5]; res[3] = b[5];
+     res[4] = a[6]; res[5] = b[6];
+     res[6] = a[7]; res[7] = b[7];
+
+     return res;
+ }
+
+ // vld1q_s16_x2
+ // vld1q_u8_x2
+ // vld1q_u8_x4
+ // vld1q_s8_x2
+ // vld1q_s8_x4
+ // TODO: double-check these work correctly
+
+ typedef struct ggml_int16x8x2_t {
+     int16x8_t val[2];
+ } ggml_int16x8x2_t;
+
+ inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+     ggml_int16x8x2_t res;
+
+     res.val[0] = vld1q_s16(ptr + 0);
+     res.val[1] = vld1q_s16(ptr + 8);
+
+     return res;
+ }
+
+ typedef struct ggml_uint8x16x2_t {
+     uint8x16_t val[2];
+ } ggml_uint8x16x2_t;
+
+ inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+     ggml_uint8x16x2_t res;
+
+     res.val[0] = vld1q_u8(ptr + 0);
+     res.val[1] = vld1q_u8(ptr + 16);
+
+     return res;
+ }
+
+ typedef struct ggml_uint8x16x4_t {
+     uint8x16_t val[4];
+ } ggml_uint8x16x4_t;
+
+ inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+     ggml_uint8x16x4_t res;
+
+     res.val[0] = vld1q_u8(ptr + 0);
+     res.val[1] = vld1q_u8(ptr + 16);
+     res.val[2] = vld1q_u8(ptr + 32);
+     res.val[3] = vld1q_u8(ptr + 48);
+
+     return res;
+ }
+
+ typedef struct ggml_int8x16x2_t {
+     int8x16_t val[2];
+ } ggml_int8x16x2_t;
+
+ inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+     ggml_int8x16x2_t res;
+
+     res.val[0] = vld1q_s8(ptr + 0);
+     res.val[1] = vld1q_s8(ptr + 16);
+
+     return res;
+ }
+
+ typedef struct ggml_int8x16x4_t {
+     int8x16_t val[4];
+ } ggml_int8x16x4_t;
+
+ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+     ggml_int8x16x4_t res;
+
+     res.val[0] = vld1q_s8(ptr + 0);
+     res.val[1] = vld1q_s8(ptr + 16);
+     res.val[2] = vld1q_s8(ptr + 32);
+     res.val[3] = vld1q_s8(ptr + 48);
+
+     return res;
+ }
+
+ // NOTE: not tested
+ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
+     int8x16_t res;
+
+     res[ 0] = a[b[ 0]];
+     res[ 1] = a[b[ 1]];
+     res[ 2] = a[b[ 2]];
+     res[ 3] = a[b[ 3]];
+     res[ 4] = a[b[ 4]];
+     res[ 5] = a[b[ 5]];
+     res[ 6] = a[b[ 6]];
+     res[ 7] = a[b[ 7]];
+     res[ 8] = a[b[ 8]];
+     res[ 9] = a[b[ 9]];
+     res[10] = a[b[10]];
+     res[11] = a[b[11]];
+     res[12] = a[b[12]];
+     res[13] = a[b[13]];
+     res[14] = a[b[14]];
+     res[15] = a[b[15]];
+
+     return res;
+ }
+
+ // NOTE: not tested
+ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
+     uint8x16_t res;
+
+     res[ 0] = a[b[ 0]];
+     res[ 1] = a[b[ 1]];
+     res[ 2] = a[b[ 2]];
+     res[ 3] = a[b[ 3]];
+     res[ 4] = a[b[ 4]];
+     res[ 5] = a[b[ 5]];
+     res[ 6] = a[b[ 6]];
+     res[ 7] = a[b[ 7]];
+     res[ 8] = a[b[ 8]];
+     res[ 9] = a[b[ 9]];
+     res[10] = a[b[10]];
+     res[11] = a[b[11]];
+     res[12] = a[b[12]];
+     res[13] = a[b[13]];
+     res[14] = a[b[14]];
+     res[15] = a[b[15]];
+
+     return res;
+ }
+
+ #else
+
+ #define ggml_int16x8x2_t int16x8x2_t
+ #define ggml_uint8x16x2_t uint8x16x2_t
+ #define ggml_uint8x16x4_t uint8x16x4_t
+ #define ggml_int8x16x2_t int8x16x2_t
+ #define ggml_int8x16x4_t int8x16x4_t
+
+ #define ggml_vld1q_s16_x2 vld1q_s16_x2
+ #define ggml_vld1q_u8_x2 vld1q_u8_x2
+ #define ggml_vld1q_u8_x4 vld1q_u8_x4
+ #define ggml_vld1q_s8_x2 vld1q_s8_x2
+ #define ggml_vld1q_s8_x4 vld1q_s8_x4
+ #define ggml_vqtbl1q_s8 vqtbl1q_s8
+ #define ggml_vqtbl1q_u8 vqtbl1q_u8
+
+ #endif // !defined(__aarch64__)
+
+ #if !defined(__ARM_FEATURE_DOTPROD)
+
+ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+     const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
+     const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+
+     return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
+ }
+
+ #else
+
+ #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
+ #endif // !defined(__ARM_FEATURE_DOTPROD)
+
+ #endif // defined(__ARM_NEON)
+
+ #if defined(__ARM_NEON) && !defined(__MSC_VER)
+
  #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
  #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)

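The ggml_vdotq_s32 fallback above emulates the ARMv8.2 dot-product instruction by widening the signed bytes with vmull_s8 and pairwise-accumulating down to four 32-bit lanes. Note that it does not match vdotq_s32 lane-for-lane (the pairwise reduction groups the byte products differently), but the sum across all four lanes is identical, which is what matters once the accumulator is reduced horizontally. A scalar reference for that invariant (illustrative only, not part of the package):

  // total of the 16 byte products; both the real vdotq_s32 and the
  // fallback above add exactly this quantity across their four lanes
  static int32_t dot16(const int8_t a[16], const int8_t b[16]) {
      int32_t sum = 0;
      for (int i = 0; i < 16; ++i) {
          sum += (int32_t)a[i] * (int32_t)b[i];
      }
      return sum;
  }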
@@ -75,8 +335,6 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {

  #else

- typedef uint16_t ggml_fp16_internal_t;
-
  #ifdef __wasm_simd128__
  #include <wasm_simd128.h>
  #else
@@ -88,7 +346,7 @@ typedef uint16_t ggml_fp16_internal_t;
  #if defined(_MSC_VER) || defined(__MINGW32__)
  #include <intrin.h>
  #else
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
  #if !defined(__riscv)
  #include <immintrin.h>
  #endif
@@ -221,7 +479,7 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {

  #endif // __F16C__

- #endif // __ARM_NEON
+ #endif // defined(__ARM_NEON) && (!defined(__MSC_VER)

  // precomputed f32 table for f16 (256 KB)
  // defined in ggml.c, initialized in ggml_init()
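The 256 KB figure follows from the table's shape: one float for every 16-bit pattern, i.e. 65536 × 4 bytes = 262144 bytes. A minimal sketch of the lookup idea (names hypothetical, not the package's API):

  #include <stdint.h>

  static float f16_to_f32_table[1 << 16]; // 65536 entries * 4 B = 256 KB

  static inline float fp16_lookup(uint16_t bits) {
      return f16_to_f32_table[bits]; // table filled once at init time
  }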
@@ -37,11 +37,15 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
  GGML_METAL_KERNEL_TYPE_SCALE,
  GGML_METAL_KERNEL_TYPE_SCALE_4,
+ GGML_METAL_KERNEL_TYPE_CLAMP,
  GGML_METAL_KERNEL_TYPE_TANH,
  GGML_METAL_KERNEL_TYPE_RELU,
  GGML_METAL_KERNEL_TYPE_GELU,
+ GGML_METAL_KERNEL_TYPE_GELU_4,
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
  GGML_METAL_KERNEL_TYPE_SILU,
+ GGML_METAL_KERNEL_TYPE_SILU_4,
  GGML_METAL_KERNEL_TYPE_SOFT_MAX,
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@@ -468,11 +472,15 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
@@ -713,6 +721,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
  case GGML_OP_MUL:
  case GGML_OP_DIV:
  case GGML_OP_SCALE:
+ case GGML_OP_CLAMP:
  case GGML_OP_SQR:
  case GGML_OP_SUM_ROWS:
      return true;
@@ -1154,8 +1163,30 @@ static enum ggml_status ggml_metal_graph_compute(

      [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case GGML_OP_CLAMP:
+ {
+     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
+
+     float min;
+     float max;
+     memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+     memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
+
+     [encoder setComputePipelineState:pipeline];
+     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+     [encoder setBytes:&min length:sizeof(min) atIndex:2];
+     [encoder setBytes:&max length:sizeof(max) atIndex:3];
+
+     const int64_t n = ggml_nelements(dst);
+
+     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case GGML_OP_UNARY:
      switch (ggml_get_unary_op(gf->nodes[i])) {
+         // we are not taking into account the strides, so for now require contiguous tensors
+         GGML_ASSERT(ggml_is_contiguous(src0));
+
          case GGML_UNARY_OP_TANH:
          {
              id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
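In the new GGML_OP_CLAMP case above, the clamp bounds travel through dst->op_params, whose storage is int32 words; memcpy is the strict-aliasing-safe way to reinterpret those words as floats. The same pattern in isolation (a sketch under that layout assumption, not the package's API):

  #include <stdint.h>
  #include <string.h>

  // read two floats that were stored bit-for-bit in consecutive int32 slots
  static void unpack_minmax(const int32_t params[2], float * min, float * max) {
      memcpy(min, &params[0], sizeof(float));
      memcpy(max, &params[1], sizeof(float));
  }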
@@ -1182,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute(
  } break;
  case GGML_UNARY_OP_GELU:
  {
-     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+     int64_t n = ggml_nelements(dst);
+
+     id<MTLComputePipelineState> pipeline = nil;
+
+     if (n % 4 == 0) {
+         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
+         n /= 4;
+     } else {
+         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+     }

      [encoder setComputePipelineState:pipeline];
      [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
      [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-     const int64_t n = ggml_nelements(dst);
-     GGML_ASSERT(n % 4 == 0);
-
-     [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_UNARY_OP_GELU_QUICK:
  {
-     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+     int64_t n = ggml_nelements(dst);
+
+     id<MTLComputePipelineState> pipeline = nil;
+
+     if (n % 4 == 0) {
+         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
+         n /= 4;
+     } else {
+         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+     }

      [encoder setComputePipelineState:pipeline];
      [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
      [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-     const int64_t n = ggml_nelements(dst);
-     GGML_ASSERT(n % 4 == 0);
-
-     [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_UNARY_OP_SILU:
  {
-     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+     int64_t n = ggml_nelements(dst);
+
+     id<MTLComputePipelineState> pipeline = nil;
+
+     if (n % 4 == 0) {
+         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
+         n /= 4;
+     } else {
+         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+     }

      [encoder setComputePipelineState:pipeline];
      [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
      [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-     const int64_t n = ggml_nelements(dst);
-     GGML_ASSERT(n % 4 == 0);
-
-     [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  default:
  {
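The GELU, GELU_QUICK and SILU cases above used to assert n % 4 == 0 and always dispatch the 4-wide kernel; they now choose between the scalar kernel and the new _4 variant at encode time, so element counts that are not multiples of 4 no longer trip an assert. The selection reduced to its core (plain C sketch, not the Metal API):

  // prefer the vectorized kernel only when the element count allows it
  int64_t n = nelements;
  if (n % 4 == 0) {
      n /= 4; // each thread then handles a float4
      // ... select the *_4 pipeline ...
  } else {
      // ... select the scalar pipeline ...
  }
  // dispatch n threadgroups with the chosen pipeline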
@@ -1683,15 +1732,10 @@ static enum ggml_status ggml_metal_graph_compute(
  } break;
  case GGML_OP_MUL_MAT_ID:
  {
-     //GGML_ASSERT(ne00 == ne10);
-     //GGML_ASSERT(ne03 == ne13);
      const int n_as = src0->ne[2];

-     // max size of the src1ids array in the kernel shared buffer
-     GGML_ASSERT(ne11 <= 4096);
-
      // src2 = ids
-     const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+     const int64_t ne20 = src2->ne[0];
      const int64_t ne21 = src2->ne[1];
      const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
      const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
@@ -1712,15 +1756,13 @@ static enum ggml_status ggml_metal_graph_compute(

      // find the break-even point where the matrix-matrix kernel becomes more efficient compared
      // to the matrix-vector kernel
-     int ne11_mm_min = n_as;
-
-     const int idx = ((int32_t *) dst->op_params)[0];
+     // ne20 = n_used_experts
+     // ne21 = n_rows
+     const int dst_rows = ne20*ne21;
+     const int dst_rows_min = n_as;

-     // batch size
-     GGML_ASSERT(ne21 == ne11); // ?
-     GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
-     const uint r2 = 1;
-     const uint r3 = 1;
+     // max size of the rowids array in the kernel shared buffer
+     GGML_ASSERT(dst_rows <= 2048);

      // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
      // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
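A worked example of the new break-even test: with n_as = 8 experts, ne20 = 2 experts used per row and ne21 = 16 rows, dst_rows = 2 × 16 = 32 > dst_rows_min = 8, so the matrix-matrix kernel is chosen; a single row (ne21 = 1) gives dst_rows = 2 and falls back to the matrix-vector kernel. (Values illustrative only.)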
@@ -1730,7 +1772,7 @@ static enum ggml_status ggml_metal_graph_compute(
      // !!!
      if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
          ne00 % 32 == 0 && ne00 >= 64 &&
-         ne11 > ne11_mm_min) {
+         dst_rows > dst_rows_min) {

          // some Metal matrix data types require aligned pointers
          // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1772,26 +1814,26 @@ static enum ggml_status ggml_metal_graph_compute(
      [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
      [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
      [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-     [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
-     [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
-     [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
-     [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-     [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-     [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
-     [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
-     [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
-     [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
-     [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
-     [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
-     [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
-     [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
-     [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
-     [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
-     [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
-
-     [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
-
-     [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+     [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+     [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+     [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+     [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+     [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
+     [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+     [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+     [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+     [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+     [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+     [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+     [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+     [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+     [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+     [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
+     [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+
+     [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+
+     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
      int nth0 = 32;
      int nth1 = 1;
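The threadgroup memory request above also checks out against the earlier assert: GGML_PAD(8192 + dst_rows*4, 16) reserves 8192 bytes for the matrix tile plus 4 bytes (one ushort2) per destination row id, so at the maximum dst_rows = 2048 it totals 8192 + 8192 = 16384 bytes, comfortably within Metal's per-threadgroup memory limit on the Apple7+ devices this path requires.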
@@ -1926,7 +1968,12 @@ static enum ggml_status ggml_metal_graph_compute(
  {
      nth0 = 4;
      nth1 = 16;
+ #if QK_K == 64
+     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+ #else
      pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+ #endif
+
  } break;
  default:
  {
@@ -1939,72 +1986,72 @@ static enum ggml_status ggml_metal_graph_compute(
      GGML_ASSERT(ne00 >= nth0*nth1);
  }

- const int64_t _ne1 = 1; // kernels needs a reference in constant memory
-
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
+
+ const int64_t _ne1 = 1;
+ const int tgz = dst_rows;

  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
      src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
      src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
      const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
      [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
      const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
      [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
      const int mem_size = 32*sizeof(float);
      [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q4_K) {
-     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
-     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
-     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
-     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
-     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
-     const int64_t ny = (_ne1 + nrows - 1)/nrows;
-     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+     const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  }
  } break;