llama_cpp 0.12.7 → 0.14.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1927,10 +1927,10 @@ static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(g
1927
1927
  return ggml_backend_kompute_buffer_type(ctx->device);
1928
1928
  }
1929
1929
 
1930
- static bool ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1930
+ static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1931
1931
  auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1932
1932
  ggml_vk_graph_compute(ctx, cgraph);
1933
- return true;
1933
+ return GGML_STATUS_SUCCESS;
1934
1934
  }
1935
1935
 
1936
1936
  static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -1953,11 +1953,17 @@ static struct ggml_backend_i kompute_backend_i = {
1953
1953
  /* .supports_op = */ ggml_backend_kompute_supports_op,
1954
1954
  };
1955
1955
 
1956
+ static ggml_guid_t ggml_backend_kompute_guid() {
1957
+ static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
1958
+ return &guid;
1959
+ }
1960
+
1956
1961
  ggml_backend_t ggml_backend_kompute_init(int device) {
1957
1962
  GGML_ASSERT(s_kompute_context == nullptr);
1958
1963
  s_kompute_context = new ggml_kompute_context(device);
1959
1964
 
1960
1965
  ggml_backend_t kompute_backend = new ggml_backend {
1966
+ /* .guid = */ ggml_backend_kompute_guid(),
1961
1967
  /* .interface = */ kompute_backend_i,
1962
1968
  /* .context = */ s_kompute_context,
1963
1969
  };
@@ -1966,7 +1972,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
1966
1972
  }
1967
1973
 
1968
1974
  bool ggml_backend_is_kompute(ggml_backend_t backend) {
1969
- return backend && backend->iface.get_name == ggml_backend_kompute_name;
1975
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
1970
1976
  }
1971
1977
 
1972
1978
  static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
@@ -61,8 +61,11 @@ enum ggml_metal_kernel_type {
61
61
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
62
62
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
63
63
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
64
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
65
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
64
66
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
65
67
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
68
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
66
69
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
67
70
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
68
71
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -85,8 +88,11 @@ enum ggml_metal_kernel_type {
85
88
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
86
89
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
87
90
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
91
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
92
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
88
93
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
89
94
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
95
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
90
96
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
91
97
  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
92
98
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -105,8 +111,11 @@ enum ggml_metal_kernel_type {
105
111
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
106
112
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
107
113
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
114
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
115
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
108
116
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
109
117
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
118
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
110
119
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
111
120
  GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
112
121
  GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -122,8 +131,11 @@ enum ggml_metal_kernel_type {
122
131
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
123
132
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
124
133
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
134
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
135
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
125
136
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
126
137
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
138
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
127
139
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
128
140
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
129
141
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -139,8 +151,11 @@ enum ggml_metal_kernel_type {
139
151
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
140
152
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
141
153
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
154
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
155
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
142
156
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
143
157
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
158
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
144
159
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
145
160
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
146
161
  GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -148,6 +163,8 @@ enum ggml_metal_kernel_type {
148
163
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
149
164
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
150
165
  GGML_METAL_KERNEL_TYPE_PAD_F32,
166
+ GGML_METAL_KERNEL_TYPE_ARANGE_F32,
167
+ GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
151
168
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
152
169
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
153
170
  GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
@@ -452,8 +469,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
452
469
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
453
470
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
454
471
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
472
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
473
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
455
474
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
456
475
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
476
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
457
477
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
458
478
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
459
479
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -476,8 +496,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
476
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
477
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
478
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
499
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
500
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
479
501
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
480
502
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
503
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
481
504
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
482
505
  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
483
506
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -496,8 +519,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
496
519
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
497
520
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
498
521
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
522
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
523
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
499
524
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
500
525
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
526
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
501
527
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
502
528
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
503
529
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -513,8 +539,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
513
539
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
514
540
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
515
541
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
542
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
543
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
516
544
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
517
545
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
546
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
518
547
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
519
548
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
520
549
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -530,8 +559,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
530
559
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
531
560
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
532
561
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
562
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
563
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
533
564
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
534
565
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
566
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
535
567
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
536
568
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
537
569
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -539,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
539
571
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
540
572
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
541
573
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
574
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
575
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
542
576
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
543
577
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
544
578
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
@@ -667,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
667
701
  return false;
668
702
  case GGML_OP_UPSCALE:
669
703
  case GGML_OP_PAD:
704
+ case GGML_OP_ARANGE:
705
+ case GGML_OP_TIMESTEP_EMBEDDING:
670
706
  case GGML_OP_ARGSORT:
671
707
  case GGML_OP_LEAKY_RELU:
672
708
  return true;
@@ -712,7 +748,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
712
748
  }
713
749
  }
714
750
 
715
- static bool ggml_metal_graph_compute(
751
+ static enum ggml_status ggml_metal_graph_compute(
716
752
  struct ggml_metal_context * ctx,
717
753
  struct ggml_cgraph * gf) {
718
754
 
@@ -1061,7 +1097,8 @@ static bool ggml_metal_graph_compute(
1061
1097
  {
1062
1098
  GGML_ASSERT(ggml_is_contiguous(src0));
1063
1099
 
1064
- const float scale = *(const float *) dst->op_params;
1100
+ float scale;
1101
+ memcpy(&scale, dst->op_params, sizeof(scale));
1065
1102
 
1066
1103
  int64_t n = ggml_nelements(dst);
1067
1104
 
@@ -1220,11 +1257,15 @@ static bool ggml_metal_graph_compute(
1220
1257
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
1221
1258
  }
1222
1259
 
1223
- const float scale = ((float *) dst->op_params)[0];
1224
- const float max_bias = ((float *) dst->op_params)[1];
1260
+ float scale;
1261
+ float max_bias;
1262
+
1263
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
1264
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
1225
1265
 
1226
1266
  const int64_t nrows_x = ggml_nrows(src0);
1227
1267
  const int64_t nrows_y = src0->ne[1];
1268
+
1228
1269
  const uint32_t n_head_kv = nrows_x/nrows_y;
1229
1270
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
1230
1271
 
@@ -1347,8 +1388,11 @@ static bool ggml_metal_graph_compute(
1347
1388
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1348
1389
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1349
1390
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1391
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1392
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1350
1393
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1351
1394
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1395
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1352
1396
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
1353
1397
  }
1354
1398
 
@@ -1483,6 +1527,18 @@ static bool ggml_metal_graph_compute(
1483
1527
  nth1 = 16;
1484
1528
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1485
1529
  } break;
1530
+ case GGML_TYPE_IQ3_S:
1531
+ {
1532
+ nth0 = 4;
1533
+ nth1 = 16;
1534
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
1535
+ } break;
1536
+ case GGML_TYPE_IQ2_S:
1537
+ {
1538
+ nth0 = 4;
1539
+ nth1 = 16;
1540
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
1541
+ } break;
1486
1542
  case GGML_TYPE_IQ1_S:
1487
1543
  {
1488
1544
  nth0 = 4;
@@ -1495,6 +1551,12 @@ static bool ggml_metal_graph_compute(
1495
1551
  nth1 = 16;
1496
1552
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
1497
1553
  } break;
1554
+ case GGML_TYPE_IQ4_XS:
1555
+ {
1556
+ nth0 = 4;
1557
+ nth1 = 16;
1558
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
1559
+ } break;
1498
1560
  default:
1499
1561
  {
1500
1562
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1527,9 +1589,9 @@ static bool ggml_metal_graph_compute(
1527
1589
  [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1528
1590
  [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1529
1591
 
1530
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1531
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1532
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
1592
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1593
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1594
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
1533
1595
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1534
1596
  }
1535
1597
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1537,12 +1599,12 @@ static bool ggml_metal_graph_compute(
1537
1599
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1538
1600
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1539
1601
  }
1540
- else if (src0t == GGML_TYPE_IQ3_XXS) {
1541
- const int mem_size = 256*4+128;
1602
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1603
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1542
1604
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1543
1605
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1544
1606
  }
1545
- else if (src0t == GGML_TYPE_IQ4_NL) {
1607
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
1546
1608
  const int mem_size = 32*sizeof(float);
1547
1609
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1548
1610
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1640,8 +1702,11 @@ static bool ggml_metal_graph_compute(
1640
1702
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
1641
1703
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
1642
1704
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
1705
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
1706
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
1643
1707
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1644
1708
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1709
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
1645
1710
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1646
1711
  }
1647
1712
 
@@ -1779,6 +1844,18 @@ static bool ggml_metal_graph_compute(
1779
1844
  nth1 = 16;
1780
1845
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
1781
1846
  } break;
1847
+ case GGML_TYPE_IQ3_S:
1848
+ {
1849
+ nth0 = 4;
1850
+ nth1 = 16;
1851
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
1852
+ } break;
1853
+ case GGML_TYPE_IQ2_S:
1854
+ {
1855
+ nth0 = 4;
1856
+ nth1 = 16;
1857
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
1858
+ } break;
1782
1859
  case GGML_TYPE_IQ1_S:
1783
1860
  {
1784
1861
  nth0 = 4;
@@ -1791,6 +1868,12 @@ static bool ggml_metal_graph_compute(
1791
1868
  nth1 = 16;
1792
1869
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
1793
1870
  } break;
1871
+ case GGML_TYPE_IQ4_XS:
1872
+ {
1873
+ nth0 = 4;
1874
+ nth1 = 16;
1875
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
1876
+ } break;
1794
1877
  default:
1795
1878
  {
1796
1879
  GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1839,9 +1922,9 @@ static bool ggml_metal_graph_compute(
1839
1922
  [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1840
1923
  }
1841
1924
 
1842
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1843
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1844
- src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
1925
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1926
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1927
+ src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
1845
1928
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1846
1929
  }
1847
1930
  else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1849,12 +1932,12 @@ static bool ggml_metal_graph_compute(
1849
1932
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1850
1933
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1851
1934
  }
1852
- else if (src2t == GGML_TYPE_IQ3_XXS) {
1853
- const int mem_size = 256*4+128;
1935
+ else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
1936
+ const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1854
1937
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1855
1938
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1856
1939
  }
1857
- else if (src2t == GGML_TYPE_IQ4_NL) {
1940
+ else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
1858
1941
  const int mem_size = 32*sizeof(float);
1859
1942
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1860
1943
  [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1900,8 +1983,11 @@ static bool ggml_metal_graph_compute(
1900
1983
  case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
1901
1984
  case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
1902
1985
  case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
1986
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
1987
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
1903
1988
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
1904
1989
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
1990
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
1905
1991
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
1906
1992
  default: GGML_ASSERT(false && "not implemented");
1907
1993
  }
@@ -2011,6 +2097,7 @@ static bool ggml_metal_graph_compute(
2011
2097
 
2012
2098
  //const int n_past = ((int32_t *) dst->op_params)[0];
2013
2099
  const int n_head = ((int32_t *) dst->op_params)[1];
2100
+
2014
2101
  float max_bias;
2015
2102
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
2016
2103
 
@@ -2225,6 +2312,50 @@ static bool ggml_metal_graph_compute(
2225
2312
 
2226
2313
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2227
2314
  } break;
2315
+ case GGML_OP_ARANGE:
2316
+ {
2317
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
2318
+
2319
+ float start;
2320
+ float step;
2321
+
2322
+ memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
2323
+ memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
2324
+
2325
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
2326
+
2327
+ [encoder setComputePipelineState:pipeline];
2328
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
2329
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
2330
+ [encoder setBytes:&start length:sizeof(start) atIndex:2];
2331
+ [encoder setBytes:&step length:sizeof(step) atIndex:3];
2332
+
2333
+ const int nth = MIN(1024, ne0);
2334
+
2335
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2336
+ } break;
2337
+ case GGML_OP_TIMESTEP_EMBEDDING:
2338
+ {
2339
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2340
+
2341
+ const int dim = dst->op_params[0];
2342
+ const int max_period = dst->op_params[1];
2343
+
2344
+ const int half = dim / 2;
2345
+
2346
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
2347
+
2348
+ [encoder setComputePipelineState:pipeline];
2349
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2350
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2351
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
2352
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
2353
+ [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
2354
+
2355
+ const int nth = MIN(1024, half);
2356
+
2357
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2358
+ } break;
2228
2359
  case GGML_OP_ARGSORT:
2229
2360
  {
2230
2361
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -2237,8 +2368,8 @@ static bool ggml_metal_graph_compute(
2237
2368
  id<MTLComputePipelineState> pipeline = nil;
2238
2369
 
2239
2370
  switch (order) {
2240
- case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
2241
- case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
2371
+ case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
2372
+ case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
2242
2373
  default: GGML_ASSERT(false);
2243
2374
  };
2244
2375
 
@@ -2353,7 +2484,7 @@ static bool ggml_metal_graph_compute(
2353
2484
  MTLCommandBufferStatus status = [command_buffer status];
2354
2485
  if (status != MTLCommandBufferStatusCompleted) {
2355
2486
  GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
2356
- return false;
2487
+ return GGML_STATUS_FAILED;
2357
2488
  }
2358
2489
  }
2359
2490
 
@@ -2362,7 +2493,7 @@ static bool ggml_metal_graph_compute(
2362
2493
  }
2363
2494
 
2364
2495
  }
2365
- return true;
2496
+ return GGML_STATUS_SUCCESS;
2366
2497
  }
2367
2498
 
2368
2499
  ////////////////////////////////////////////////////////////////////////////////
@@ -2664,7 +2795,7 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffe
2664
2795
  UNUSED(backend);
2665
2796
  }
2666
2797
 
2667
- GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2798
+ GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2668
2799
  struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
2669
2800
 
2670
2801
  return ggml_metal_graph_compute(metal_ctx, cgraph);
@@ -2696,6 +2827,11 @@ void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void *
2696
2827
  ggml_metal_log_user_data = user_data;
2697
2828
  }
2698
2829
 
2830
+ static ggml_guid_t ggml_backend_metal_guid(void) {
2831
+ static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
2832
+ return &guid;
2833
+ }
2834
+
2699
2835
  ggml_backend_t ggml_backend_metal_init(void) {
2700
2836
  struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
2701
2837
 
@@ -2706,6 +2842,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
2706
2842
  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
2707
2843
 
2708
2844
  *metal_backend = (struct ggml_backend) {
2845
+ /* .guid = */ ggml_backend_metal_guid(),
2709
2846
  /* .interface = */ ggml_backend_metal_i,
2710
2847
  /* .context = */ ctx,
2711
2848
  };
@@ -2714,7 +2851,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
2714
2851
  }
2715
2852
 
2716
2853
  bool ggml_backend_is_metal(ggml_backend_t backend) {
2717
- return backend && backend->iface.get_name == ggml_backend_metal_name;
2854
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
2718
2855
  }
2719
2856
 
2720
2857
  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {