llama_cpp 0.12.6 → 0.12.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -53,11 +53,23 @@ extern "C" {
53
53
  //
54
54
  #include <arm_neon.h>
55
55
 
56
- #define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
57
- #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
56
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
57
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
58
+
59
+ #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
60
+
61
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
62
+ __fp16 tmp;
63
+ memcpy(&tmp, &h, sizeof(ggml_fp16_t));
64
+ return (float)tmp;
65
+ }
58
66
 
59
- #define GGML_FP16_TO_FP32(x) ((float) (x))
60
- #define GGML_FP32_TO_FP16(x) (x)
67
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
68
+ ggml_fp16_t res;
69
+ __fp16 tmp = f;
70
+ memcpy(&res, &tmp, sizeof(ggml_fp16_t));
71
+ return res;
72
+ }
61
73
 
62
74
  #else
63
75
 
@@ -214,8 +226,7 @@ extern float ggml_table_f32_f16[1 << 16];
214
226
  // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
215
227
  // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
216
228
  // This is also true for POWER9.
217
- #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
218
-
229
+ #if !defined(GGML_FP16_TO_FP32)
219
230
  inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
220
231
  uint16_t s;
221
232
  memcpy(&s, &f, sizeof(uint16_t));
@@ -223,8 +234,10 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
223
234
  }
224
235
 
225
236
  #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
226
- #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
237
+ #endif
227
238
 
239
+ #if !defined(GGML_FP32_TO_FP16)
240
+ #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
228
241
  #endif
229
242
 
230
243
  #define GGML_HASHTABLE_FULL ((size_t)-1)
@@ -61,6 +61,8 @@ enum ggml_metal_kernel_type {
61
61
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
62
62
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
63
63
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
64
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
65
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
64
66
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
65
67
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
66
68
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -83,6 +85,8 @@ enum ggml_metal_kernel_type {
83
85
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
84
86
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
85
87
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
88
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
89
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
86
90
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
87
91
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
88
92
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -101,6 +105,8 @@ enum ggml_metal_kernel_type {
101
105
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
102
106
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
103
107
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
108
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
109
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
104
110
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
105
111
  GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
106
112
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -116,6 +122,8 @@ enum ggml_metal_kernel_type {
116
122
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
117
123
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
118
124
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
125
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
126
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
119
127
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
120
128
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
121
129
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -131,6 +139,8 @@ enum ggml_metal_kernel_type {
131
139
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
132
140
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
133
141
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
142
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
143
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
134
144
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
135
145
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
136
146
  GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -176,7 +186,7 @@ struct ggml_metal_context {
176
186
  // MSL code
177
187
  // TODO: move the contents here when ready
178
188
  // for now it is easier to work in a separate file
179
- //static NSString * const msl_library_source = @"see metal.metal";
189
+ // static NSString * const msl_library_source = @"see metal.metal";
180
190
 
181
191
  // Here to assist with NSBundle Path Hack
182
192
  @interface GGMLMetalClass : NSObject
@@ -272,6 +282,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
272
282
  return NULL;
273
283
  }
274
284
  } else {
285
+ #if GGML_METAL_EMBED_LIBRARY
286
+ GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);
287
+
288
+ extern const char ggml_metallib_start[];
289
+ extern const char ggml_metallib_end[];
290
+
291
+ NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
292
+ #else
275
293
  GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
276
294
 
277
295
  NSString * sourcePath;
@@ -294,6 +312,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
294
312
  GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
295
313
  return NULL;
296
314
  }
315
+ #endif
297
316
 
298
317
  @autoreleasepool {
299
318
  // dictionary of preprocessor macros
@@ -433,6 +452,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
433
452
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
434
453
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
435
454
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
455
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
456
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
436
457
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
437
458
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
438
459
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -455,6 +476,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
455
476
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
456
477
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
457
478
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
479
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
480
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
458
481
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
459
482
  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
460
483
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -473,6 +496,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
473
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
474
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
475
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
499
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
500
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
476
501
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
477
502
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
478
503
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -488,6 +513,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
488
513
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
489
514
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
490
515
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
516
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
517
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
491
518
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
492
519
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
493
520
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -503,6 +530,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
503
530
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
504
531
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
505
532
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
533
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
534
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
506
535
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
507
536
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
508
537
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -728,6 +757,7 @@ static bool ggml_metal_graph_compute(
728
757
 
729
758
  size_t offs_src0 = 0;
730
759
  size_t offs_src1 = 0;
760
+ size_t offs_src2 = 0;
731
761
  size_t offs_dst = 0;
732
762
 
733
763
  id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
@@ -746,6 +776,7 @@ static bool ggml_metal_graph_compute(
746
776
 
747
777
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
748
778
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
779
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
749
780
  struct ggml_tensor * dst = gf->nodes[i];
750
781
 
751
782
  switch (dst->op) {
@@ -807,6 +838,7 @@ static bool ggml_metal_graph_compute(
807
838
 
808
839
  id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
809
840
  id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
841
+ id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
810
842
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
811
843
 
812
844
  //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
@@ -1188,7 +1220,16 @@ static bool ggml_metal_graph_compute(
1188
1220
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
1189
1221
  }
1190
1222
 
1191
- const float scale = ((float *) dst->op_params)[0];
1223
+ const float scale = ((float *) dst->op_params)[0];
1224
+ const float max_bias = ((float *) dst->op_params)[1];
1225
+
1226
+ const int64_t nrows_x = ggml_nrows(src0);
1227
+ const int64_t nrows_y = src0->ne[1];
1228
+ const uint32_t n_head_kv = nrows_x/nrows_y;
1229
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
1230
+
1231
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1232
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1192
1233
 
1193
1234
  [encoder setComputePipelineState:pipeline];
1194
1235
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1197,11 +1238,20 @@ static bool ggml_metal_graph_compute(
1197
1238
  } else {
1198
1239
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1199
1240
  }
1200
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1201
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1202
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1203
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1204
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1241
+ if (id_src2) {
1242
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
1243
+ } else {
1244
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
1245
+ }
1246
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1247
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
1248
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
1249
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1250
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
1251
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
1252
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
1253
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
1254
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
1205
1255
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1206
1256
 
1207
1257
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1297,6 +1347,8 @@ static bool ggml_metal_graph_compute(
1297
1347
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1298
1348
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1299
1349
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1350
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1351
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1300
1352
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1301
1353
  }
1302
1354
 
@@ -1431,6 +1483,18 @@ static bool ggml_metal_graph_compute(
1431
1483
  nth1 = 16;
1432
1484
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1433
1485
  } break;
1486
+ case GGML_TYPE_IQ1_S:
1487
+ {
1488
+ nth0 = 4;
1489
+ nth1 = 16;
1490
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
1491
+ } break;
1492
+ case GGML_TYPE_IQ4_NL:
1493
+ {
1494
+ nth0 = 4;
1495
+ nth1 = 16;
1496
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
1497
+ } break;
1434
1498
  default:
1435
1499
  {
1436
1500
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1465,7 +1529,7 @@ static bool ggml_metal_graph_compute(
1465
1529
 
1466
1530
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1467
1531
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1468
- src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1532
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
1469
1533
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1470
1534
  }
1471
1535
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1478,6 +1542,11 @@ static bool ggml_metal_graph_compute(
1478
1542
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1479
1543
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1480
1544
  }
1545
+ else if (src0t == GGML_TYPE_IQ4_NL) {
1546
+ const int mem_size = 32*sizeof(float);
1547
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1548
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1549
+ }
1481
1550
  else if (src0t == GGML_TYPE_Q4_K) {
1482
1551
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1483
1552
  }
@@ -1514,8 +1583,6 @@ static bool ggml_metal_graph_compute(
1514
1583
  // max size of the src1ids array in the kernel stack
1515
1584
  GGML_ASSERT(ne11 <= 512);
1516
1585
 
1517
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
1518
-
1519
1586
  const int64_t ne20 = src2 ? src2->ne[0] : 0;
1520
1587
  const int64_t ne21 = src2 ? src2->ne[1] : 0;
1521
1588
  const int64_t ne22 = src2 ? src2->ne[2] : 0;
@@ -1573,6 +1640,8 @@ static bool ggml_metal_graph_compute(
1573
1640
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1574
1641
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1575
1642
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
1643
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1644
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1576
1645
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1577
1646
  }
1578
1647
 
@@ -1710,6 +1779,18 @@ static bool ggml_metal_graph_compute(
1710
1779
  nth1 = 16;
1711
1780
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
1712
1781
  } break;
1782
+ case GGML_TYPE_IQ1_S:
1783
+ {
1784
+ nth0 = 4;
1785
+ nth1 = 16;
1786
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
1787
+ } break;
1788
+ case GGML_TYPE_IQ4_NL:
1789
+ {
1790
+ nth0 = 4;
1791
+ nth1 = 16;
1792
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
1793
+ } break;
1713
1794
  default:
1714
1795
  {
1715
1796
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1760,7 +1841,7 @@ static bool ggml_metal_graph_compute(
1760
1841
 
1761
1842
  if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1762
1843
  src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1763
- src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1844
+ src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
1764
1845
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1765
1846
  }
1766
1847
  else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1773,6 +1854,11 @@ static bool ggml_metal_graph_compute(
1773
1854
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1774
1855
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1775
1856
  }
1857
+ else if (src2t == GGML_TYPE_IQ4_NL) {
1858
+ const int mem_size = 32*sizeof(float);
1859
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1860
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1861
+ }
1776
1862
  else if (src2t == GGML_TYPE_Q4_K) {
1777
1863
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1778
1864
  }
@@ -1814,6 +1900,8 @@ static bool ggml_metal_graph_compute(
1814
1900
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1815
1901
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1816
1902
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
1903
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
1904
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
1817
1905
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
1818
1906
  default: GGML_ASSERT(false && "not implemented");
1819
1907
  }