llama_cpp 0.12.6 → 0.13.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -53,11 +53,23 @@ extern "C" {
53
53
  //
54
54
  #include <arm_neon.h>
55
55
 
56
- #define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
57
- #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
56
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
57
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
58
+
59
+ #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
60
+
61
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
62
+ __fp16 tmp;
63
+ memcpy(&tmp, &h, sizeof(ggml_fp16_t));
64
+ return (float)tmp;
65
+ }
58
66
 
59
- #define GGML_FP16_TO_FP32(x) ((float) (x))
60
- #define GGML_FP32_TO_FP16(x) (x)
67
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
68
+ ggml_fp16_t res;
69
+ __fp16 tmp = f;
70
+ memcpy(&res, &tmp, sizeof(ggml_fp16_t));
71
+ return res;
72
+ }
61
73
 
62
74
  #else
63
75
 
@@ -214,8 +226,7 @@ extern float ggml_table_f32_f16[1 << 16];
214
226
  // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
215
227
  // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
216
228
  // This is also true for POWER9.
217
- #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
218
-
229
+ #if !defined(GGML_FP16_TO_FP32)
219
230
  inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
220
231
  uint16_t s;
221
232
  memcpy(&s, &f, sizeof(uint16_t));
@@ -223,8 +234,10 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
223
234
  }
224
235
 
225
236
  #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
226
- #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
237
+ #endif
227
238
 
239
+ #if !defined(GGML_FP32_TO_FP16)
240
+ #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
228
241
  #endif
229
242
 
230
243
  #define GGML_HASHTABLE_FULL ((size_t)-1)
@@ -1953,11 +1953,17 @@ static struct ggml_backend_i kompute_backend_i = {
1953
1953
  /* .supports_op = */ ggml_backend_kompute_supports_op,
1954
1954
  };
1955
1955
 
1956
+ static ggml_guid_t ggml_backend_kompute_guid() {
1957
+ static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
1958
+ return &guid;
1959
+ }
1960
+
1956
1961
  ggml_backend_t ggml_backend_kompute_init(int device) {
1957
1962
  GGML_ASSERT(s_kompute_context == nullptr);
1958
1963
  s_kompute_context = new ggml_kompute_context(device);
1959
1964
 
1960
1965
  ggml_backend_t kompute_backend = new ggml_backend {
1966
+ /* .guid = */ ggml_backend_kompute_guid(),
1961
1967
  /* .interface = */ kompute_backend_i,
1962
1968
  /* .context = */ s_kompute_context,
1963
1969
  };
@@ -1966,7 +1972,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
1966
1972
  }
1967
1973
 
1968
1974
  bool ggml_backend_is_kompute(ggml_backend_t backend) {
1969
- return backend && backend->iface.get_name == ggml_backend_kompute_name;
1975
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
1970
1976
  }
1971
1977
 
1972
1978
  static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
@@ -61,6 +61,11 @@ enum ggml_metal_kernel_type {
61
61
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
62
62
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
63
63
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
64
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
65
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
66
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
67
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
68
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
64
69
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
65
70
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
66
71
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -83,6 +88,11 @@ enum ggml_metal_kernel_type {
83
88
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
84
89
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
85
90
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
91
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
92
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
93
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
94
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
95
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
86
96
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
87
97
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
88
98
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -101,6 +111,11 @@ enum ggml_metal_kernel_type {
101
111
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
102
112
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
103
113
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
114
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
115
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
116
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
117
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
118
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
104
119
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
105
120
  GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
106
121
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -116,6 +131,11 @@ enum ggml_metal_kernel_type {
116
131
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
117
132
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
118
133
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
134
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
135
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
136
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
137
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
138
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
119
139
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
120
140
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
121
141
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -131,6 +151,11 @@ enum ggml_metal_kernel_type {
131
151
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
132
152
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
133
153
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
154
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
155
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
156
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
157
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
158
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
134
159
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
135
160
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
136
161
  GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -176,7 +201,7 @@ struct ggml_metal_context {
176
201
  // MSL code
177
202
  // TODO: move the contents here when ready
178
203
  // for now it is easier to work in a separate file
179
- //static NSString * const msl_library_source = @"see metal.metal";
204
+ // static NSString * const msl_library_source = @"see metal.metal";
180
205
 
181
206
  // Here to assist with NSBundle Path Hack
182
207
  @interface GGMLMetalClass : NSObject
@@ -272,6 +297,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
272
297
  return NULL;
273
298
  }
274
299
  } else {
300
+ #if GGML_METAL_EMBED_LIBRARY
301
+ GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);
302
+
303
+ extern const char ggml_metallib_start[];
304
+ extern const char ggml_metallib_end[];
305
+
306
+ NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
307
+ #else
275
308
  GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
276
309
 
277
310
  NSString * sourcePath;
@@ -294,6 +327,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
294
327
  GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
295
328
  return NULL;
296
329
  }
330
+ #endif
297
331
 
298
332
  @autoreleasepool {
299
333
  // dictionary of preprocessor macros
@@ -433,6 +467,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
433
467
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
434
468
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
435
469
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
470
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
471
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
472
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
473
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
474
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
436
475
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
437
476
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
438
477
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -455,6 +494,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
455
494
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
456
495
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
457
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
497
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
498
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
499
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
500
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
501
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
458
502
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
459
503
  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
460
504
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -473,6 +517,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
473
517
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
474
518
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
475
519
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
520
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
521
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
522
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
523
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
524
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
476
525
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
477
526
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
478
527
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -488,6 +537,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
488
537
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
489
538
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
490
539
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
540
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
541
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
542
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
543
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
544
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
491
545
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
492
546
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
493
547
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -503,6 +557,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
503
557
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
504
558
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
505
559
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
560
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
561
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
562
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
563
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
564
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
506
565
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
507
566
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
508
567
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -728,6 +787,7 @@ static bool ggml_metal_graph_compute(
728
787
 
729
788
  size_t offs_src0 = 0;
730
789
  size_t offs_src1 = 0;
790
+ size_t offs_src2 = 0;
731
791
  size_t offs_dst = 0;
732
792
 
733
793
  id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
@@ -746,6 +806,7 @@ static bool ggml_metal_graph_compute(
746
806
 
747
807
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
748
808
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
809
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
749
810
  struct ggml_tensor * dst = gf->nodes[i];
750
811
 
751
812
  switch (dst->op) {
@@ -807,6 +868,7 @@ static bool ggml_metal_graph_compute(
807
868
 
808
869
  id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
809
870
  id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
871
+ id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
810
872
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
811
873
 
812
874
  //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
@@ -1188,7 +1250,16 @@ static bool ggml_metal_graph_compute(
1188
1250
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
1189
1251
  }
1190
1252
 
1191
- const float scale = ((float *) dst->op_params)[0];
1253
+ const float scale = ((float *) dst->op_params)[0];
1254
+ const float max_bias = ((float *) dst->op_params)[1];
1255
+
1256
+ const int64_t nrows_x = ggml_nrows(src0);
1257
+ const int64_t nrows_y = src0->ne[1];
1258
+ const uint32_t n_head_kv = nrows_x/nrows_y;
1259
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
1260
+
1261
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1262
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1192
1263
 
1193
1264
  [encoder setComputePipelineState:pipeline];
1194
1265
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1197,11 +1268,20 @@ static bool ggml_metal_graph_compute(
1197
1268
  } else {
1198
1269
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1199
1270
  }
1200
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1201
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1202
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1203
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1204
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1271
+ if (id_src2) {
1272
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
1273
+ } else {
1274
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
1275
+ }
1276
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1277
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
1278
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
1279
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1280
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
1281
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
1282
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
1283
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
1284
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
1205
1285
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1206
1286
 
1207
1287
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1297,6 +1377,11 @@ static bool ggml_metal_graph_compute(
1297
1377
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1298
1378
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1299
1379
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1380
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1381
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1382
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1383
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1384
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1300
1385
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1301
1386
  }
1302
1387
 
@@ -1431,6 +1516,36 @@ static bool ggml_metal_graph_compute(
1431
1516
  nth1 = 16;
1432
1517
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1433
1518
  } break;
1519
+ case GGML_TYPE_IQ3_S:
1520
+ {
1521
+ nth0 = 4;
1522
+ nth1 = 16;
1523
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
1524
+ } break;
1525
+ case GGML_TYPE_IQ2_S:
1526
+ {
1527
+ nth0 = 4;
1528
+ nth1 = 16;
1529
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
1530
+ } break;
1531
+ case GGML_TYPE_IQ1_S:
1532
+ {
1533
+ nth0 = 4;
1534
+ nth1 = 16;
1535
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
1536
+ } break;
1537
+ case GGML_TYPE_IQ4_NL:
1538
+ {
1539
+ nth0 = 4;
1540
+ nth1 = 16;
1541
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
1542
+ } break;
1543
+ case GGML_TYPE_IQ4_XS:
1544
+ {
1545
+ nth0 = 4;
1546
+ nth1 = 16;
1547
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
1548
+ } break;
1434
1549
  default:
1435
1550
  {
1436
1551
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1463,9 +1578,9 @@ static bool ggml_metal_graph_compute(
1463
1578
  [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1464
1579
  [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1465
1580
 
1466
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1467
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1468
- src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1581
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1582
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1583
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
1469
1584
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1470
1585
  }
1471
1586
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1473,11 +1588,16 @@ static bool ggml_metal_graph_compute(
1473
1588
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1474
1589
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1475
1590
  }
1476
- else if (src0t == GGML_TYPE_IQ3_XXS) {
1477
- const int mem_size = 256*4+128;
1591
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1592
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1478
1593
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1479
1594
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1480
1595
  }
1596
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
1597
+ const int mem_size = 32*sizeof(float);
1598
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1599
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1600
+ }
1481
1601
  else if (src0t == GGML_TYPE_Q4_K) {
1482
1602
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1483
1603
  }
@@ -1514,8 +1634,6 @@ static bool ggml_metal_graph_compute(
1514
1634
  // max size of the src1ids array in the kernel stack
1515
1635
  GGML_ASSERT(ne11 <= 512);
1516
1636
 
1517
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
1518
-
1519
1637
  const int64_t ne20 = src2 ? src2->ne[0] : 0;
1520
1638
  const int64_t ne21 = src2 ? src2->ne[1] : 0;
1521
1639
  const int64_t ne22 = src2 ? src2->ne[2] : 0;
@@ -1573,6 +1691,11 @@ static bool ggml_metal_graph_compute(
1573
1691
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1574
1692
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1575
1693
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
1694
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
1695
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
1696
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1697
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1698
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
1576
1699
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1577
1700
  }
1578
1701
 
@@ -1710,6 +1833,36 @@ static bool ggml_metal_graph_compute(
1710
1833
  nth1 = 16;
1711
1834
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
1712
1835
  } break;
1836
+ case GGML_TYPE_IQ3_S:
1837
+ {
1838
+ nth0 = 4;
1839
+ nth1 = 16;
1840
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
1841
+ } break;
1842
+ case GGML_TYPE_IQ2_S:
1843
+ {
1844
+ nth0 = 4;
1845
+ nth1 = 16;
1846
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
1847
+ } break;
1848
+ case GGML_TYPE_IQ1_S:
1849
+ {
1850
+ nth0 = 4;
1851
+ nth1 = 16;
1852
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
1853
+ } break;
1854
+ case GGML_TYPE_IQ4_NL:
1855
+ {
1856
+ nth0 = 4;
1857
+ nth1 = 16;
1858
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
1859
+ } break;
1860
+ case GGML_TYPE_IQ4_XS:
1861
+ {
1862
+ nth0 = 4;
1863
+ nth1 = 16;
1864
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
1865
+ } break;
1713
1866
  default:
1714
1867
  {
1715
1868
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1758,9 +1911,9 @@ static bool ggml_metal_graph_compute(
1758
1911
  [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1759
1912
  }
1760
1913
 
1761
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1762
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1763
- src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1914
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1915
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1916
+ src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
1764
1917
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1765
1918
  }
1766
1919
  else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1768,11 +1921,16 @@ static bool ggml_metal_graph_compute(
1768
1921
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1769
1922
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1770
1923
  }
1771
- else if (src2t == GGML_TYPE_IQ3_XXS) {
1772
- const int mem_size = 256*4+128;
1924
+ else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
1925
+ const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1773
1926
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1774
1927
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1775
1928
  }
1929
+ else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
1930
+ const int mem_size = 32*sizeof(float);
1931
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1932
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1933
+ }
1776
1934
  else if (src2t == GGML_TYPE_Q4_K) {
1777
1935
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1778
1936
  }
@@ -1814,6 +1972,11 @@ static bool ggml_metal_graph_compute(
1814
1972
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1815
1973
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1816
1974
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
1975
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
1976
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
1977
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
1978
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
1979
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
1817
1980
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
1818
1981
  default: GGML_ASSERT(false && "not implemented");
1819
1982
  }
@@ -2149,8 +2312,8 @@ static bool ggml_metal_graph_compute(
2149
2312
  id<MTLComputePipelineState> pipeline = nil;
2150
2313
 
2151
2314
  switch (order) {
2152
- case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
2153
- case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
2315
+ case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
2316
+ case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
2154
2317
  default: GGML_ASSERT(false);
2155
2318
  };
2156
2319
 
@@ -2608,6 +2771,11 @@ void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void *
2608
2771
  ggml_metal_log_user_data = user_data;
2609
2772
  }
2610
2773
 
2774
+ static ggml_guid_t ggml_backend_metal_guid(void) {
2775
+ static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
2776
+ return &guid;
2777
+ }
2778
+
2611
2779
  ggml_backend_t ggml_backend_metal_init(void) {
2612
2780
  struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
2613
2781
 
@@ -2618,6 +2786,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
2618
2786
  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
2619
2787
 
2620
2788
  *metal_backend = (struct ggml_backend) {
2789
+ /* .guid = */ ggml_backend_metal_guid(),
2621
2790
  /* .interface = */ ggml_backend_metal_i,
2622
2791
  /* .context = */ ctx,
2623
2792
  };
@@ -2626,7 +2795,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
2626
2795
  }
2627
2796
 
2628
2797
  bool ggml_backend_is_metal(ggml_backend_t backend) {
2629
- return backend && backend->iface.get_name == ggml_backend_metal_name;
2798
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
2630
2799
  }
2631
2800
 
2632
2801
  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {