llama_cpp 0.12.7 → 0.14.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1927,10 +1927,10 @@ static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(g
1927
1927
  return ggml_backend_kompute_buffer_type(ctx->device);
1928
1928
  }
1929
1929
 
1930
- static bool ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1930
+ static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1931
1931
  auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1932
1932
  ggml_vk_graph_compute(ctx, cgraph);
1933
- return true;
1933
+ return GGML_STATUS_SUCCESS;
1934
1934
  }
1935
1935
 
1936
1936
  static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -1953,11 +1953,17 @@ static struct ggml_backend_i kompute_backend_i = {
1953
1953
  /* .supports_op = */ ggml_backend_kompute_supports_op,
1954
1954
  };
1955
1955
 
1956
+ static ggml_guid_t ggml_backend_kompute_guid() {
1957
+ static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
1958
+ return &guid;
1959
+ }
1960
+
1956
1961
  ggml_backend_t ggml_backend_kompute_init(int device) {
1957
1962
  GGML_ASSERT(s_kompute_context == nullptr);
1958
1963
  s_kompute_context = new ggml_kompute_context(device);
1959
1964
 
1960
1965
  ggml_backend_t kompute_backend = new ggml_backend {
1966
+ /* .guid = */ ggml_backend_kompute_guid(),
1961
1967
  /* .interface = */ kompute_backend_i,
1962
1968
  /* .context = */ s_kompute_context,
1963
1969
  };
@@ -1966,7 +1972,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
1966
1972
  }
1967
1973
 
1968
1974
  bool ggml_backend_is_kompute(ggml_backend_t backend) {
1969
- return backend && backend->iface.get_name == ggml_backend_kompute_name;
1975
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
1970
1976
  }
1971
1977
 
1972
1978
  static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
@@ -61,8 +61,11 @@ enum ggml_metal_kernel_type {
61
61
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
62
62
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
63
63
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
64
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
65
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
64
66
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
65
67
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
68
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
66
69
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
67
70
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
68
71
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -85,8 +88,11 @@ enum ggml_metal_kernel_type {
85
88
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
86
89
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
87
90
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
91
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
92
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
88
93
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
89
94
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
95
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
90
96
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
91
97
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
92
98
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -105,8 +111,11 @@ enum ggml_metal_kernel_type {
105
111
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
106
112
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
107
113
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
114
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
115
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
108
116
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
109
117
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
118
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
110
119
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
111
120
  GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
112
121
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -122,8 +131,11 @@ enum ggml_metal_kernel_type {
122
131
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
123
132
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
124
133
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
134
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
135
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
125
136
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
126
137
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
138
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
127
139
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
128
140
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
129
141
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -139,8 +151,11 @@ enum ggml_metal_kernel_type {
139
151
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
140
152
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
141
153
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
154
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
155
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
142
156
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
143
157
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
158
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
144
159
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
145
160
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
146
161
  GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -148,6 +163,8 @@ enum ggml_metal_kernel_type {
148
163
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
149
164
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
150
165
  GGML_METAL_KERNEL_TYPE_PAD_F32,
166
+ GGML_METAL_KERNEL_TYPE_ARANGE_F32,
167
+ GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
151
168
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
152
169
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
153
170
  GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
@@ -452,8 +469,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
452
469
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
453
470
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
454
471
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
472
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
473
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
455
474
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
456
475
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
476
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
457
477
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
458
478
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
459
479
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -476,8 +496,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
476
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
477
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
478
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
499
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
500
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
479
501
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
480
502
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
503
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
481
504
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
482
505
  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
483
506
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -496,8 +519,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
496
519
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
497
520
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
498
521
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
522
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
523
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
499
524
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
500
525
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
526
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
501
527
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
502
528
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
503
529
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -513,8 +539,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
513
539
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
514
540
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
515
541
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
542
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
543
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
516
544
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
517
545
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
546
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
518
547
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
519
548
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
520
549
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -530,8 +559,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
530
559
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
531
560
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
532
561
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
562
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
563
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
533
564
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
534
565
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
566
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
535
567
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
536
568
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
537
569
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -539,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
539
571
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
540
572
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
541
573
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
574
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
575
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
542
576
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
543
577
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
544
578
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
@@ -667,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
667
701
  return false;
668
702
  case GGML_OP_UPSCALE:
669
703
  case GGML_OP_PAD:
704
+ case GGML_OP_ARANGE:
705
+ case GGML_OP_TIMESTEP_EMBEDDING:
670
706
  case GGML_OP_ARGSORT:
671
707
  case GGML_OP_LEAKY_RELU:
672
708
  return true;
@@ -712,7 +748,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
712
748
  }
713
749
  }
714
750
 
715
- static bool ggml_metal_graph_compute(
751
+ static enum ggml_status ggml_metal_graph_compute(
716
752
  struct ggml_metal_context * ctx,
717
753
  struct ggml_cgraph * gf) {
718
754
 
@@ -1061,7 +1097,8 @@ static bool ggml_metal_graph_compute(
1061
1097
  {
1062
1098
  GGML_ASSERT(ggml_is_contiguous(src0));
1063
1099
 
1064
- const float scale = *(const float *) dst->op_params;
1100
+ float scale;
1101
+ memcpy(&scale, dst->op_params, sizeof(scale));
1065
1102
 
1066
1103
  int64_t n = ggml_nelements(dst);
1067
1104
 
@@ -1220,11 +1257,15 @@ static bool ggml_metal_graph_compute(
1220
1257
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
1221
1258
  }
1222
1259
 
1223
- const float scale = ((float *) dst->op_params)[0];
1224
- const float max_bias = ((float *) dst->op_params)[1];
1260
+ float scale;
1261
+ float max_bias;
1262
+
1263
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
1264
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
1225
1265
 
1226
1266
  const int64_t nrows_x = ggml_nrows(src0);
1227
1267
  const int64_t nrows_y = src0->ne[1];
1268
+
1228
1269
  const uint32_t n_head_kv = nrows_x/nrows_y;
1229
1270
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
1230
1271
 
@@ -1347,8 +1388,11 @@ static bool ggml_metal_graph_compute(
1347
1388
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1348
1389
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1349
1390
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1391
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1392
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1350
1393
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1351
1394
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1395
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1352
1396
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1353
1397
  }
1354
1398
 
@@ -1483,6 +1527,18 @@ static bool ggml_metal_graph_compute(
1483
1527
  nth1 = 16;
1484
1528
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1485
1529
  } break;
1530
+ case GGML_TYPE_IQ3_S:
1531
+ {
1532
+ nth0 = 4;
1533
+ nth1 = 16;
1534
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
1535
+ } break;
1536
+ case GGML_TYPE_IQ2_S:
1537
+ {
1538
+ nth0 = 4;
1539
+ nth1 = 16;
1540
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
1541
+ } break;
1486
1542
  case GGML_TYPE_IQ1_S:
1487
1543
  {
1488
1544
  nth0 = 4;
@@ -1495,6 +1551,12 @@ static bool ggml_metal_graph_compute(
1495
1551
  nth1 = 16;
1496
1552
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
1497
1553
  } break;
1554
+ case GGML_TYPE_IQ4_XS:
1555
+ {
1556
+ nth0 = 4;
1557
+ nth1 = 16;
1558
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
1559
+ } break;
1498
1560
  default:
1499
1561
  {
1500
1562
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1527,9 +1589,9 @@ static bool ggml_metal_graph_compute(
1527
1589
  [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1528
1590
  [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1529
1591
 
1530
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1531
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1532
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
1592
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1593
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1594
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
1533
1595
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1534
1596
  }
1535
1597
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1537,12 +1599,12 @@ static bool ggml_metal_graph_compute(
1537
1599
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1538
1600
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1539
1601
  }
1540
- else if (src0t == GGML_TYPE_IQ3_XXS) {
1541
- const int mem_size = 256*4+128;
1602
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1603
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1542
1604
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1543
1605
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1544
1606
  }
1545
- else if (src0t == GGML_TYPE_IQ4_NL) {
1607
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
1546
1608
  const int mem_size = 32*sizeof(float);
1547
1609
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1548
1610
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1640,8 +1702,11 @@ static bool ggml_metal_graph_compute(
1640
1702
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1641
1703
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1642
1704
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
1705
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
1706
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
1643
1707
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1644
1708
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1709
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
1645
1710
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1646
1711
  }
1647
1712
 
@@ -1779,6 +1844,18 @@ static bool ggml_metal_graph_compute(
1779
1844
  nth1 = 16;
1780
1845
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
1781
1846
  } break;
1847
+ case GGML_TYPE_IQ3_S:
1848
+ {
1849
+ nth0 = 4;
1850
+ nth1 = 16;
1851
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
1852
+ } break;
1853
+ case GGML_TYPE_IQ2_S:
1854
+ {
1855
+ nth0 = 4;
1856
+ nth1 = 16;
1857
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
1858
+ } break;
1782
1859
  case GGML_TYPE_IQ1_S:
1783
1860
  {
1784
1861
  nth0 = 4;
@@ -1791,6 +1868,12 @@ static bool ggml_metal_graph_compute(
1791
1868
  nth1 = 16;
1792
1869
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
1793
1870
  } break;
1871
+ case GGML_TYPE_IQ4_XS:
1872
+ {
1873
+ nth0 = 4;
1874
+ nth1 = 16;
1875
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
1876
+ } break;
1794
1877
  default:
1795
1878
  {
1796
1879
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1839,9 +1922,9 @@ static bool ggml_metal_graph_compute(
1839
1922
  [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1840
1923
  }
1841
1924
 
1842
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1843
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1844
- src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
1925
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1926
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1927
+ src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
1845
1928
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1846
1929
  }
1847
1930
  else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1849,12 +1932,12 @@ static bool ggml_metal_graph_compute(
1849
1932
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1850
1933
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1851
1934
  }
1852
- else if (src2t == GGML_TYPE_IQ3_XXS) {
1853
- const int mem_size = 256*4+128;
1935
+ else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
1936
+ const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1854
1937
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1855
1938
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1856
1939
  }
1857
- else if (src2t == GGML_TYPE_IQ4_NL) {
1940
+ else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
1858
1941
  const int mem_size = 32*sizeof(float);
1859
1942
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1860
1943
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1900,8 +1983,11 @@ static bool ggml_metal_graph_compute(
1900
1983
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1901
1984
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1902
1985
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
1986
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
1987
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
1903
1988
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
1904
1989
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
1990
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
1905
1991
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
1906
1992
  default: GGML_ASSERT(false && "not implemented");
1907
1993
  }
@@ -2011,6 +2097,7 @@ static bool ggml_metal_graph_compute(
2011
2097
 
2012
2098
  //const int n_past = ((int32_t *) dst->op_params)[0];
2013
2099
  const int n_head = ((int32_t *) dst->op_params)[1];
2100
+
2014
2101
  float max_bias;
2015
2102
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
2016
2103
 
@@ -2225,6 +2312,50 @@ static bool ggml_metal_graph_compute(
2225
2312
 
2226
2313
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2227
2314
  } break;
2315
+ case GGML_OP_ARANGE:
2316
+ {
2317
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
2318
+
2319
+ float start;
2320
+ float step;
2321
+
2322
+ memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
2323
+ memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
2324
+
2325
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
2326
+
2327
+ [encoder setComputePipelineState:pipeline];
2328
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
2329
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
2330
+ [encoder setBytes:&start length:sizeof(start) atIndex:2];
2331
+ [encoder setBytes:&step length:sizeof(step) atIndex:3];
2332
+
2333
+ const int nth = MIN(1024, ne0);
2334
+
2335
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2336
+ } break;
2337
+ case GGML_OP_TIMESTEP_EMBEDDING:
2338
+ {
2339
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2340
+
2341
+ const int dim = dst->op_params[0];
2342
+ const int max_period = dst->op_params[1];
2343
+
2344
+ const int half = dim / 2;
2345
+
2346
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
2347
+
2348
+ [encoder setComputePipelineState:pipeline];
2349
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2350
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2351
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
2352
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
2353
+ [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
2354
+
2355
+ const int nth = MIN(1024, half);
2356
+
2357
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2358
+ } break;
2228
2359
  case GGML_OP_ARGSORT:
2229
2360
  {
2230
2361
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -2237,8 +2368,8 @@ static bool ggml_metal_graph_compute(
2237
2368
  id<MTLComputePipelineState> pipeline = nil;
2238
2369
 
2239
2370
  switch (order) {
2240
- case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
2241
- case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
2371
+ case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
2372
+ case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
2242
2373
  default: GGML_ASSERT(false);
2243
2374
  };
2244
2375
 
@@ -2353,7 +2484,7 @@ static bool ggml_metal_graph_compute(
2353
2484
  MTLCommandBufferStatus status = [command_buffer status];
2354
2485
  if (status != MTLCommandBufferStatusCompleted) {
2355
2486
  GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
2356
- return false;
2487
+ return GGML_STATUS_FAILED;
2357
2488
  }
2358
2489
  }
2359
2490
 
@@ -2362,7 +2493,7 @@ static bool ggml_metal_graph_compute(
2362
2493
  }
2363
2494
 
2364
2495
  }
2365
- return true;
2496
+ return GGML_STATUS_SUCCESS;
2366
2497
  }
2367
2498
 
2368
2499
  ////////////////////////////////////////////////////////////////////////////////
@@ -2664,7 +2795,7 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffe
2664
2795
  UNUSED(backend);
2665
2796
  }
2666
2797
 
2667
- GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2798
+ GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2668
2799
  struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
2669
2800
 
2670
2801
  return ggml_metal_graph_compute(metal_ctx, cgraph);
@@ -2696,6 +2827,11 @@ void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void *
2696
2827
  ggml_metal_log_user_data = user_data;
2697
2828
  }
2698
2829
 
2830
+ static ggml_guid_t ggml_backend_metal_guid(void) {
2831
+ static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
2832
+ return &guid;
2833
+ }
2834
+
2699
2835
  ggml_backend_t ggml_backend_metal_init(void) {
2700
2836
  struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
2701
2837
 
@@ -2706,6 +2842,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
2706
2842
  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
2707
2843
 
2708
2844
  *metal_backend = (struct ggml_backend) {
2845
+ /* .guid = */ ggml_backend_metal_guid(),
2709
2846
  /* .interface = */ ggml_backend_metal_i,
2710
2847
  /* .context = */ ctx,
2711
2848
  };
@@ -2714,7 +2851,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
2714
2851
  }
2715
2852
 
2716
2853
  bool ggml_backend_is_metal(ggml_backend_t backend) {
2717
- return backend && backend->iface.get_name == ggml_backend_metal_name;
2854
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
2718
2855
  }
2719
2856
 
2720
2857
  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {