llama_cpp 0.12.6 → 0.12.7

@@ -53,11 +53,23 @@ extern "C" {
  //
  #include <arm_neon.h>

- #define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
- #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+ #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+ __fp16 tmp;
+ memcpy(&tmp, &h, sizeof(ggml_fp16_t));
+ return (float)tmp;
+ }

- #define GGML_FP16_TO_FP32(x) ((float) (x))
- #define GGML_FP32_TO_FP16(x) (x)
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+ ggml_fp16_t res;
+ __fp16 tmp = f;
+ memcpy(&res, &tmp, sizeof(ggml_fp16_t));
+ return res;
+ }

  #else

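The replacement macros above route through explicit helpers instead of plain casts: the value is copied into an __fp16 with memcpy, so the conversion no longer assumes that ggml_fp16_t itself is a native half-precision type. Below is a minimal, self-contained sketch of the same round-trip, assuming an AArch64/NEON compiler where __fp16 is available; the fp16_to_fp32/fp32_to_fp16 names and the main driver are illustrative only, not part of the gem:

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    typedef uint16_t ggml_fp16_t;   // same 16-bit storage type ggml uses for fp16 values

    // fp16 bit pattern -> float, via the hardware __fp16 type (mirrors the helper above).
    static inline float fp16_to_fp32(ggml_fp16_t h) {
        __fp16 tmp;
        memcpy(&tmp, &h, sizeof(h));
        return (float) tmp;
    }

    // float -> fp16 bit pattern.
    static inline ggml_fp16_t fp32_to_fp16(float f) {
        ggml_fp16_t res;
        __fp16 tmp = (__fp16) f;
        memcpy(&res, &tmp, sizeof(res));
        return res;
    }

    int main(void) {
        const float x = 0.125f;                 // exactly representable in fp16
        const ggml_fp16_t h = fp32_to_fp16(x);
        printf("%f -> 0x%04x -> %f\n", x, h, fp16_to_fp32(h));
        return 0;
    }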
@@ -214,8 +226,7 @@ extern float ggml_table_f32_f16[1 << 16];
  // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
  // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
  // This is also true for POWER9.
- #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
-
+ #if !defined(GGML_FP16_TO_FP32)
  inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
  uint16_t s;
  memcpy(&s, &f, sizeof(uint16_t));
@@ -223,8 +234,10 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
  }

  #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
- #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+ #endif

+ #if !defined(GGML_FP32_TO_FP16)
+ #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
  #endif

  #define GGML_HASHTABLE_FULL ((size_t)-1)
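On targets that keep the table-based path, GGML_FP16_TO_FP32 resolves to ggml_lookup_fp16_to_fp32 above, which simply indexes ggml_table_f32_f16 with the raw 16-bit pattern, one precomputed float per possible value. A hedged sketch of how such a 64K-entry table could be filled, reusing ggml_fp16_t and fp16_to_fp32 from the previous sketch; init_f16_table and f32_from_f16_table are illustrative names, not ggml's (ggml fills its real table during library initialization):

    // One float per possible fp16 bit pattern, so conversion becomes a single table load.
    static float f32_from_f16_table[1 << 16];

    static void init_f16_table(void) {
        for (uint32_t i = 0; i < (1u << 16); ++i) {
            f32_from_f16_table[i] = fp16_to_fp32((ggml_fp16_t) i);
        }
    }

    // Sketch equivalent of ggml_lookup_fp16_to_fp32 against the table above.
    static inline float lookup_f16_to_f32(ggml_fp16_t h) {
        return f32_from_f16_table[h];
    }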
@@ -61,6 +61,8 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -83,6 +85,8 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -101,6 +105,8 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -116,6 +122,8 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -131,6 +139,8 @@ enum ggml_metal_kernel_type {
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
  GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -176,7 +186,7 @@ struct ggml_metal_context {
  // MSL code
  // TODO: move the contents here when ready
  // for now it is easier to work in a separate file
- //static NSString * const msl_library_source = @"see metal.metal";
+ // static NSString * const msl_library_source = @"see metal.metal";

  // Here to assist with NSBundle Path Hack
  @interface GGMLMetalClass : NSObject
@@ -272,6 +282,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  return NULL;
  }
  } else {
+ #if GGML_METAL_EMBED_LIBRARY
+ GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);
+
+ extern const char ggml_metallib_start[];
+ extern const char ggml_metallib_end[];
+
+ NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
+ #else
  GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);

  NSString * sourcePath;
@@ -294,6 +312,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
+ #endif

  @autoreleasepool {
  // dictionary of preprocessor macros
@@ -433,6 +452,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -455,6 +476,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -473,6 +496,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -488,6 +513,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -503,6 +530,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -728,6 +757,7 @@ static bool ggml_metal_graph_compute(

  size_t offs_src0 = 0;
  size_t offs_src1 = 0;
+ size_t offs_src2 = 0;
  size_t offs_dst = 0;

  id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
@@ -746,6 +776,7 @@ static bool ggml_metal_graph_compute(

  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
  struct ggml_tensor * dst = gf->nodes[i];

  switch (dst->op) {
@@ -807,6 +838,7 @@ static bool ggml_metal_graph_compute(

  id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
  id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+ id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;

  //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
@@ -1188,7 +1220,16 @@ static bool ggml_metal_graph_compute(
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
  }

- const float scale = ((float *) dst->op_params)[0];
+ const float scale = ((float *) dst->op_params)[0];
+ const float max_bias = ((float *) dst->op_params)[1];
+
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+ const uint32_t n_head_kv = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1197,11 +1238,20 @@ static bool ggml_metal_graph_compute(
  } else {
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
  }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ if (id_src2) {
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
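The extra scalars bound above let the SOFT_MAX kernel apply ALiBi when max_bias is non-zero: m0 and m1 are the two geometric bases and n_head_log2 marks where the schedule switches between them. Below is a host-side sketch of the per-head slope this parameterization implies, following the standard ALiBi scheme; the alibi_slope helper and the example values are illustrative only (the real per-head computation happens inside the Metal kernel):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    // Slope for attention head h: the first n_head_log2 heads use powers of m0,
    // the remaining heads use odd powers of m1.
    static float alibi_slope(uint32_t h, float m0, float m1, uint32_t n_head_log2) {
        return h < n_head_log2 ? powf(m0, (float) (h + 1))
                               : powf(m1, (float) (2*(h - n_head_log2) + 1));
    }

    int main(void) {
        const float    max_bias    = 8.0f;  // example; taken from dst->op_params[1]
        const uint32_t n_head_kv   = 12;    // example head count (nrows_x / nrows_y)
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        for (uint32_t h = 0; h < n_head_kv; ++h) {
            printf("head %2u: slope %.6f\n", h, alibi_slope(h, m0, m1, n_head_log2));
        }
        return 0;                           // build with -lm
    }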
@@ -1297,6 +1347,8 @@ static bool ggml_metal_graph_compute(
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
  }

@@ -1431,6 +1483,18 @@ static bool ggml_metal_graph_compute(
  nth1 = 16;
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
  } break;
+ case GGML_TYPE_IQ1_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_NL:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
+ } break;
  default:
  {
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1465,7 +1529,7 @@ static bool ggml_metal_graph_compute(

  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
- src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1478,6 +1542,11 @@ static bool ggml_metal_graph_compute(
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
+ else if (src0t == GGML_TYPE_IQ4_NL) {
+ const int mem_size = 32*sizeof(float);
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
  else if (src0t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
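The dispatch branches above encode how many output rows each threadgroup covers: eight rows for most block-quantized types, hence the (ne01 + 7)/8 ceiling division, and four rows for IQ4_NL and Q4_K, hence (ne01 + 3)/4. A tiny sketch of that ceiling division, with names chosen here purely for illustration:

    #include <stdint.h>

    // Threadgroups needed so that rows_per_tg rows per group cover all nrows rows.
    static inline int64_t n_threadgroups(int64_t nrows, int64_t rows_per_tg) {
        return (nrows + rows_per_tg - 1) / rows_per_tg;
    }

    // n_threadgroups(ne01, 8) == (ne01 + 7)/8  (Q2_K, IQ1_S, ...)
    // n_threadgroups(ne01, 4) == (ne01 + 3)/4  (IQ4_NL, Q4_K)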
@@ -1514,8 +1583,6 @@ static bool ggml_metal_graph_compute(
  // max size of the src1ids array in the kernel stack
  GGML_ASSERT(ne11 <= 512);

- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
-
  const int64_t ne20 = src2 ? src2->ne[0] : 0;
  const int64_t ne21 = src2 ? src2->ne[1] : 0;
  const int64_t ne22 = src2 ? src2->ne[2] : 0;
@@ -1573,6 +1640,8 @@ static bool ggml_metal_graph_compute(
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
  }

@@ -1710,6 +1779,18 @@ static bool ggml_metal_graph_compute(
  nth1 = 16;
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
  } break;
+ case GGML_TYPE_IQ1_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_NL:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+ } break;
  default:
  {
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1760,7 +1841,7 @@ static bool ggml_metal_graph_compute(

  if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
  src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
- src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+ src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1773,6 +1854,11 @@ static bool ggml_metal_graph_compute(
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
+ else if (src2t == GGML_TYPE_IQ4_NL) {
+ const int mem_size = 32*sizeof(float);
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
  else if (src2t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
@@ -1814,6 +1900,8 @@ static bool ggml_metal_graph_compute(
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
  default: GGML_ASSERT(false && "not implemented");
  }