llama_cpp 0.12.7 → 0.13.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -104,6 +104,8 @@ extern "C" {
     };
 
     struct ggml_backend {
+        ggml_guid_t guid;
+
         struct ggml_backend_i iface;
 
         ggml_backend_context_t context;
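The new guid field gives every backend instance a stable 16-byte identity, so backend type checks no longer have to compare function pointers. The GUID typedefs and the ggml_guid_matches helper live in ggml.h and are not part of this diff; the sketch below shows their assumed shape (a 16-byte array compared with memcmp), not the verbatim upstream definitions.

    #include <stdbool.h>
    #include <stdint.h>
    #include <string.h>

    // Assumed shape of the GUID support this diff relies on (see ggml.h for
    // the authoritative definitions).
    typedef uint8_t ggml_guid[16];
    typedef ggml_guid * ggml_guid_t;

    // Two backend objects are of the same kind iff their GUID bytes match.
    static bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
        return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
    }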
@@ -12,7 +12,6 @@
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
-
 // backend buffer type
 
 const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
@@ -159,6 +158,13 @@ bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml
 
 // backend
 
+ggml_guid_t ggml_backend_guid(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return NULL;
+    }
+    return backend->guid;
+}
+
 const char * ggml_backend_name(ggml_backend_t backend) {
     if (backend == NULL) {
         return "NULL";
@@ -781,6 +787,11 @@ static struct ggml_backend_i cpu_backend_i = {
     /* .supports_op = */ ggml_backend_cpu_supports_op,
 };
 
+static ggml_guid_t ggml_backend_cpu_guid(void) {
+    static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };
+    return &guid;
+}
+
 ggml_backend_t ggml_backend_cpu_init(void) {
     struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
     if (ctx == NULL) {
@@ -800,6 +811,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
     }
 
     *cpu_backend = (struct ggml_backend) {
+        /* .guid      = */ ggml_backend_cpu_guid(),
         /* .interface = */ cpu_backend_i,
         /* .context   = */ ctx
     };
@@ -807,7 +819,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
 }
 
 GGML_CALL bool ggml_backend_is_cpu(ggml_backend_t backend) {
-    return backend && backend->iface.get_name == ggml_backend_cpu_name;
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());
 }
 
 void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
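Switching ggml_backend_is_cpu from a get_name function-pointer comparison to a GUID comparison keeps the check reliable when the same backend code is linked into more than one shared object, where the function pointers would differ while the GUID bytes stay identical. A minimal usage sketch follows; ggml_backend_cpu_init, ggml_backend_is_cpu, ggml_backend_cpu_set_n_threads, ggml_backend_name, and ggml_backend_free are the public API shown in this diff, while configure_backend is only an illustrative wrapper:

    #include <stdio.h>
    #include "ggml-backend.h"

    // Apply a CPU-only knob when the backend's GUID identifies it as the CPU
    // backend; other backend kinds are left untouched.
    static void configure_backend(ggml_backend_t backend) {
        if (ggml_backend_is_cpu(backend)) {
            ggml_backend_cpu_set_n_threads(backend, 8);
        }
        printf("using backend: %s\n", ggml_backend_name(backend));
    }

    int main(void) {
        ggml_backend_t cpu = ggml_backend_cpu_init();
        configure_backend(cpu);
        ggml_backend_free(cpu);
        return 0;
    }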
@@ -49,7 +49,7 @@ extern "C" {
     // Backend
     //
 
-
+    GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend);
     GGML_API const char * ggml_backend_name(ggml_backend_t backend);
     GGML_API void ggml_backend_free(ggml_backend_t backend);
 
@@ -1953,11 +1953,17 @@ static struct ggml_backend_i kompute_backend_i = {
     /* .supports_op = */ ggml_backend_kompute_supports_op,
 };
 
+static ggml_guid_t ggml_backend_kompute_guid() {
+    static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
+    return &guid;
+}
+
 ggml_backend_t ggml_backend_kompute_init(int device) {
     GGML_ASSERT(s_kompute_context == nullptr);
     s_kompute_context = new ggml_kompute_context(device);
 
     ggml_backend_t kompute_backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_kompute_guid(),
         /* .interface = */ kompute_backend_i,
         /* .context   = */ s_kompute_context,
     };
@@ -1966,7 +1972,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
 }
 
 bool ggml_backend_is_kompute(ggml_backend_t backend) {
-    return backend && backend->iface.get_name == ggml_backend_kompute_name;
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
 }
 
 static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
@@ -61,8 +61,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -85,8 +88,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
     //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -105,8 +111,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -122,8 +131,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -139,8 +151,11 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -452,8 +467,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -476,8 +494,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -496,8 +517,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -513,8 +537,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -530,8 +557,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -1347,8 +1377,11 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+                case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32  ].pipeline; break;
+                case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32  ].pipeline; break;
                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
                 case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
+                case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
             }
 
@@ -1483,6 +1516,18 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ3_S:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
+                    } break;
+                case GGML_TYPE_IQ2_S:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
+                    } break;
                 case GGML_TYPE_IQ1_S:
                     {
                         nth0 = 4;
@@ -1495,6 +1540,12 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ4_XS:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1527,9 +1578,9 @@ static bool ggml_metal_graph_compute(
             [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
             [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
 
-            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
-                src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
-                src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
+            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
+                src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1537,12 +1588,12 @@ static bool ggml_metal_graph_compute(
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
-            else if (src0t == GGML_TYPE_IQ3_XXS) {
-                const int mem_size = 256*4+128;
+            else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
-            else if (src0t == GGML_TYPE_IQ4_NL) {
+            else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                 const int mem_size = 32*sizeof(float);
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
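The threadgroup-memory reservations above differ per quantization type; a plausible reading (not stated in this diff) is that each mul_mv kernel stages its per-type lookup table in shared memory, and the IQ3_S grid is simply larger than the IQ3_XXS grid plus its sign table. The helper below only restates the arithmetic; mul_mv_shared_mem_size is a hypothetical name, not part of ggml:

    #include "ggml.h"

    // Shared-memory sizes implied by the dispatch code above, in bytes.
    static int mul_mv_shared_mem_size(enum ggml_type t) {
        switch (t) {
            case GGML_TYPE_IQ3_XXS: return 256*4 + 128;              // 1152
            case GGML_TYPE_IQ3_S:   return 512*4;                    // 2048
            case GGML_TYPE_IQ4_NL:
            case GGML_TYPE_IQ4_XS:  return (int)(32*sizeof(float));  // 128
            default:                return 0;                        // no reservation
        }
    }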
@@ -1640,8 +1691,11 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+                case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32  ].pipeline; break;
+                case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32  ].pipeline; break;
                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
                 case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
+                case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
             }
 
@@ -1779,6 +1833,18 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ3_S:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
+                    } break;
+                case GGML_TYPE_IQ2_S:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
+                    } break;
                 case GGML_TYPE_IQ1_S:
                     {
                         nth0 = 4;
@@ -1791,6 +1857,12 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ4_XS:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1839,9 +1911,9 @@ static bool ggml_metal_graph_compute(
                 [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
             }
 
-            if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
-                src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
-                src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
+            if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+                src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+                src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
             else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1849,12 +1921,12 @@ static bool ggml_metal_graph_compute(
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
-            else if (src2t == GGML_TYPE_IQ3_XXS) {
-                const int mem_size = 256*4+128;
+            else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
+                const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
-            else if (src2t == GGML_TYPE_IQ4_NL) {
+            else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
                 const int mem_size = 32*sizeof(float);
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1900,8 +1972,11 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+                case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S  ].pipeline; break;
+                case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S  ].pipeline; break;
                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
                 case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
+                case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
                 case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                 default: GGML_ASSERT(false && "not implemented");
             }
@@ -2237,8 +2312,8 @@ static bool ggml_metal_graph_compute(
             id<MTLComputePipelineState> pipeline = nil;
 
             switch (order) {
-                case GGML_SORT_ASC:        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC ].pipeline; break;
-                case GGML_SORT_DESC:       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
+                case GGML_SORT_ORDER_ASC:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC ].pipeline; break;
+                case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
                 default: GGML_ASSERT(false);
             };
 
@@ -2696,6 +2771,11 @@ void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void *
     ggml_metal_log_user_data = user_data;
 }
 
+static ggml_guid_t ggml_backend_metal_guid(void) {
+    static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
+    return &guid;
+}
+
 ggml_backend_t ggml_backend_metal_init(void) {
     struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
 
@@ -2706,6 +2786,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
     ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
 
     *metal_backend = (struct ggml_backend) {
+        /* .guid      = */ ggml_backend_metal_guid(),
         /* .interface = */ ggml_backend_metal_i,
         /* .context   = */ ctx,
     };
@@ -2714,7 +2795,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
 }
 
 bool ggml_backend_is_metal(ggml_backend_t backend) {
-    return backend && backend->iface.get_name == ggml_backend_metal_name;
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
 }
 
 void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {