llama_cpp 0.12.5 → 0.12.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -53,11 +53,23 @@ extern "C" {
  //
  #include <arm_neon.h>
 
- #define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
- #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+ #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+     __fp16 tmp;
+     memcpy(&tmp, &h, sizeof(ggml_fp16_t));
+     return (float)tmp;
+ }
 
- #define GGML_FP16_TO_FP32(x) ((float) (x))
- #define GGML_FP32_TO_FP16(x) (x)
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+     ggml_fp16_t res;
+     __fp16 tmp = f;
+     memcpy(&res, &tmp, sizeof(ggml_fp16_t));
+     return res;
+ }
 
  #else
 
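The two new inline helpers reinterpret the 16-bit storage type through the compiler's native __fp16, so on ARM the conversion compiles down to a single instruction instead of going through a lookup table. A minimal round-trip sketch in C, assuming an ARM toolchain with NEON and that ggml_fp16_t is a plain 16-bit storage type (the typedef and main below are illustrative, not part of the package):

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    typedef uint16_t ggml_fp16_t;              // assumption: plain 16-bit storage, as in ggml

    static inline float fp16_to_fp32(ggml_fp16_t h) {
        __fp16 tmp;
        memcpy(&tmp, &h, sizeof(h));           // reinterpret the stored bits as a half float
        return (float) tmp;
    }

    static inline ggml_fp16_t fp32_to_fp16(float f) {
        ggml_fp16_t res;
        __fp16 tmp = (__fp16) f;               // hardware rounds to half precision
        memcpy(&res, &tmp, sizeof(res));
        return res;
    }

    int main(void) {
        float x = 3.14159f;
        ggml_fp16_t h = fp32_to_fp16(x);
        printf("%f -> 0x%04x -> %f\n", x, (unsigned) h, fp16_to_fp32(h));  // round-trips within ~1e-3
        return 0;
    }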
@@ -214,8 +226,7 @@ extern float ggml_table_f32_f16[1 << 16];
  // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
  // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
  // This is also true for POWER9.
- #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
-
+ #if !defined(GGML_FP16_TO_FP32)
  inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
      uint16_t s;
      memcpy(&s, &f, sizeof(uint16_t));
@@ -223,8 +234,10 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
  }
 
  #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
- #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+ #endif
 
+ #if !defined(GGML_FP32_TO_FP16)
+ #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
  #endif
 
  #define GGML_HASHTABLE_FULL ((size_t)-1)
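For builds without a fast native conversion, GGML_FP16_TO_FP32 remains a table lookup: ggml_table_f32_f16 stores one float for every 16-bit pattern, so a conversion is a single indexed load. A hedged sketch of how such a table can be filled once at startup, using the header's own ggml_fp16_t and GGML_COMPUTE_FP16_TO_FP32 (ggml does this during initialization; the function name here is illustrative):

    #include <stdint.h>
    #include <string.h>

    extern float ggml_table_f32_f16[1 << 16];      // declared above in this header

    // One-time init: run the scalar converter over every possible 16-bit pattern.
    static void init_fp16_table(void) {
        for (uint32_t i = 0; i < (1u << 16); ++i) {
            const uint16_t u = (uint16_t) i;
            ggml_fp16_t h;
            memcpy(&h, &u, sizeof(h));                              // bit pattern -> fp16 storage
            ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(h);   // scalar conversion, paid once
        }
    }

After that, every GGML_FP16_TO_FP32(x) in a hot loop costs one load from the 256 KiB table.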
@@ -61,6 +61,8 @@ enum ggml_metal_kernel_type {
      GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
      GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
      GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
+     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
+     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
      GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
      GGML_METAL_KERNEL_TYPE_RMS_NORM,
      GGML_METAL_KERNEL_TYPE_GROUP_NORM,
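The new enum entries here (and the matching MUL_MV/MUL_MM/*_ID variants below) wire up Metal kernels for two quantization formats added upstream in this release: IQ1_S (roughly 1.5 bits per weight) and IQ4_NL (4 bits per weight with a non-linear codebook). As a rough illustration of what a non-linear 4-bit format means, dequantization maps each 4-bit index through a small signed value table and scales by a per-block factor; the table values and layout below are illustrative, not the exact ggml definitions:

    #include <stdint.h>

    // Illustrative non-linear codebook: denser near zero, coarser at the extremes.
    static const int8_t kvalues[16] = {
        -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
    };

    // Dequantize one 32-weight block: each byte packs two 4-bit codebook indices,
    // `d` is the per-block scale (stored as fp16 in the real format).
    static void dequant_4bit_nl(const uint8_t qs[16], float d, float out[32]) {
        for (int i = 0; i < 16; ++i) {
            out[i]      = d * kvalues[qs[i] & 0x0F];   // low nibbles  -> first half
            out[i + 16] = d * kvalues[qs[i] >> 4];     // high nibbles -> second half
        }
    }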
@@ -83,6 +85,8 @@ enum ggml_metal_kernel_type {
      GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
+     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
+     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
      //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
      GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -101,6 +105,8 @@ enum ggml_metal_kernel_type {
      GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
+     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
+     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -116,6 +122,8 @@ enum ggml_metal_kernel_type {
      GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
+     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
+     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -131,6 +139,8 @@ enum ggml_metal_kernel_type {
      GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
      GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
+     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
+     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
      GGML_METAL_KERNEL_TYPE_ROPE_F32,
      GGML_METAL_KERNEL_TYPE_ROPE_F16,
      GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -176,7 +186,7 @@ struct ggml_metal_context {
  // MSL code
  // TODO: move the contents here when ready
  // for now it is easier to work in a separate file
- //static NSString * const msl_library_source = @"see metal.metal";
+ // static NSString * const msl_library_source = @"see metal.metal";
 
  // Here to assist with NSBundle Path Hack
  @interface GGMLMetalClass : NSObject
@@ -272,6 +282,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
          return NULL;
      }
  } else {
+ #if GGML_METAL_EMBED_LIBRARY
+     GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);
+
+     extern const char ggml_metallib_start[];
+     extern const char ggml_metallib_end[];
+
+     NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
+ #else
      GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
      NSString * sourcePath;
@@ -294,6 +312,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
          GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
          return NULL;
      }
+ #endif
 
  @autoreleasepool {
      // dictionary of preprocessor macros
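When GGML_METAL_EMBED_LIBRARY is defined, the Metal shader source is baked into the binary and bracketed by the ggml_metallib_start/ggml_metallib_end symbols, so no default.metallib or ggml-metal.metal file has to be located at run time. A small C sketch of the consuming side, assuming the build emits the two symbols around the embedded UTF-8 source (commonly done with an assembler .incbin stanza or a generated source file):

    #include <stddef.h>

    // Assumed to be provided by the build: they delimit the raw contents of
    // ggml-metal.metal placed in a read-only data section.
    extern const char ggml_metallib_start[];
    extern const char ggml_metallib_end[];

    // The embedded bytes are not NUL-terminated, which is why the loader above
    // passes an explicit length to initWithBytes:length:encoding:.
    static size_t embedded_metallib_size(void) {
        return (size_t)(ggml_metallib_end - ggml_metallib_start);
    }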
@@ -433,6 +452,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -455,6 +476,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -473,6 +496,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -488,6 +513,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -503,6 +530,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -687,6 +716,7 @@ static bool ggml_metal_graph_compute(
          struct ggml_metal_context * ctx,
          struct ggml_cgraph * gf) {
 
+ @autoreleasepool {
  MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
  edesc.dispatchType = MTLDispatchTypeSerial;
 
@@ -727,6 +757,7 @@ static bool ggml_metal_graph_compute(
 
  size_t offs_src0 = 0;
  size_t offs_src1 = 0;
+ size_t offs_src2 = 0;
  size_t offs_dst = 0;
 
  id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
@@ -745,6 +776,7 @@ static bool ggml_metal_graph_compute(
 
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
  struct ggml_tensor * dst = gf->nodes[i];
 
  switch (dst->op) {
@@ -806,6 +838,7 @@ static bool ggml_metal_graph_compute(
 
  id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
  id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+ id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
 
  //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
@@ -1187,7 +1220,16 @@ static bool ggml_metal_graph_compute(
      pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
  }
 
- const float scale = ((float *) dst->op_params)[0];
+ const float scale    = ((float *) dst->op_params)[0];
+ const float max_bias = ((float *) dst->op_params)[1];
+
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+ const uint32_t n_head_kv = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+ const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
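The soft_max path now reads a second op parameter, max_bias, and derives the two ALiBi slope bases m0 and m1 on the host. The kernel presumably combines them with the per-head index to bias attention scores by distance; a worked C sketch of the standard ALiBi slope recipe these constants support (the per-head formula is the usual ALiBi construction and is stated here as an assumption about what the kernel does with m0, m1, and n_head_log2):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    // Slope for attention head h out of n_head heads, given max_bias
    // (ALiBi adds slope * -distance to each score). Mirrors the m0/m1 setup above.
    static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return h < n_head_log2 ? powf(m0, (float)(h + 1))
                               : powf(m1, (float)(2*(h - n_head_log2) + 1));
    }

    int main(void) {
        // e.g. 8 heads with max_bias = 8: slopes 1/2, 1/4, ..., 1/256
        for (uint32_t h = 0; h < 8; ++h) {
            printf("head %u: slope %.6f\n", h, alibi_slope(h, 8, 8.0f));
        }
        return 0;
    }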
@@ -1196,11 +1238,20 @@ static bool ggml_metal_graph_compute(
  } else {
      [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
  }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ if (id_src2) {
+     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ } else {
+     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1296,6 +1347,8 @@ static bool ggml_metal_graph_compute(
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
  case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
+ case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
  }
 
@@ -1430,6 +1483,18 @@ static bool ggml_metal_graph_compute(
      nth1 = 16;
      pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
  } break;
+ case GGML_TYPE_IQ1_S:
+ {
+     nth0 = 4;
+     nth1 = 16;
+     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_NL:
+ {
+     nth0 = 4;
+     nth1 = 16;
+     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
+ } break;
  default:
  {
      GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1464,7 +1529,7 @@ static bool ggml_metal_graph_compute(
 
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
      src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
-     src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
+     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
      [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1477,6 +1542,11 @@ static bool ggml_metal_graph_compute(
      [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
      [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
+ else if (src0t == GGML_TYPE_IQ4_NL) {
+     const int mem_size = 32*sizeof(float);
+     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
  else if (src0t == GGML_TYPE_Q4_K) {
      [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
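For the new matrix-vector kernels, each threadgroup covers several output rows: IQ1_S joins the 8-rows-per-threadgroup grid used by the classic quant types, while IQ4_NL gets a 4-rows-per-threadgroup grid plus 32 floats of threadgroup memory (presumably a small shared scratch/lookup buffer for one 32-wide simdgroup). A short C sketch of the ceil-division grid math mirrored from these dispatches (the sizes are examples; MTLSizeMake itself is not called here):

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    // Integer ceil-division used when mapping output rows onto threadgroups.
    static int64_t ceil_div(int64_t n, int64_t d) { return (n + d - 1) / d; }

    int main(void) {
        const int64_t ne01 = 4096;            // output rows (rows of src0)
        const int64_t ne11 = 1;               // columns of src1 (1 for a mat-vec)
        const int64_t ne12 = 1, ne13 = 1;     // batch dimensions

        // IQ1_S path: 8 rows per threadgroup -> grid x = (ne01 + 7)/8
        printf("IQ1_S  grid: %lld x %lld x %lld\n",
               (long long) ceil_div(ne01, 8), (long long) ne11, (long long) (ne12*ne13));

        // IQ4_NL path: 4 rows per threadgroup and 32 floats of threadgroup memory
        printf("IQ4_NL grid: %lld x %lld x %lld, threadgroup mem: %zu bytes\n",
               (long long) ceil_div(ne01, 4), (long long) ne11, (long long) (ne12*ne13),
               (size_t) (32*sizeof(float)));
        return 0;
    }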
@@ -1513,8 +1583,6 @@ static bool ggml_metal_graph_compute(
  // max size of the src1ids array in the kernel stack
  GGML_ASSERT(ne11 <= 512);
 
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
-
  const int64_t ne20 = src2 ? src2->ne[0] : 0;
  const int64_t ne21 = src2 ? src2->ne[1] : 0;
  const int64_t ne22 = src2 ? src2->ne[2] : 0;
@@ -1572,6 +1640,8 @@ static bool ggml_metal_graph_compute(
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
  case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
+ case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
  }
 
@@ -1709,6 +1779,18 @@ static bool ggml_metal_graph_compute(
      nth1 = 16;
      pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
  } break;
+ case GGML_TYPE_IQ1_S:
+ {
+     nth0 = 4;
+     nth1 = 16;
+     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_NL:
+ {
+     nth0 = 4;
+     nth1 = 16;
+     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+ } break;
  default:
  {
      GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1759,7 +1841,7 @@ static bool ggml_metal_graph_compute(
 
  if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
      src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
-     src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+     src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
      [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1772,6 +1854,11 @@ static bool ggml_metal_graph_compute(
      [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
      [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
+ else if (src2t == GGML_TYPE_IQ4_NL) {
+     const int mem_size = 32*sizeof(float);
+     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
  else if (src2t == GGML_TYPE_Q4_K) {
      [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
@@ -1813,6 +1900,8 @@ static bool ggml_metal_graph_compute(
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
  case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+ case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
+ case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
  case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
  default: GGML_ASSERT(false && "not implemented");
  }
@@ -2272,6 +2361,7 @@ static bool ggml_metal_graph_compute(
  [[MTLCaptureManager sharedCaptureManager] stopCapture];
  }
 
+ }
  return true;
  }