llama_cpp 0.14.2 → 0.14.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -17,29 +17,17 @@ extern "C" {
17
17
 
18
18
  #define GGML_CUDA_MAX_DEVICES 16
19
19
 
20
- // Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
21
- GGML_API GGML_CALL void ggml_init_cublas(void);
22
-
23
- // Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
24
- GGML_API GGML_CALL bool ggml_cublas_loaded(void);
25
-
26
- GGML_API GGML_CALL void * ggml_cuda_host_malloc(size_t size);
27
- GGML_API GGML_CALL void ggml_cuda_host_free(void * ptr);
28
-
29
- GGML_API GGML_CALL bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
30
- GGML_API GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
31
-
32
- GGML_API GGML_CALL int ggml_cuda_get_device_count(void);
33
- GGML_API GGML_CALL void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
34
-
35
20
  // backend API
36
21
  GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
37
22
 
38
23
  GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
39
24
 
25
+ // device buffer
40
26
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
27
+
41
28
  // split tensor buffer that splits matrices by rows across multiple devices
42
29
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
30
+
43
31
  // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
44
32
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
45
33
 
@@ -47,6 +35,9 @@ GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
47
35
  GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
48
36
  GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
49
37
 
38
+ GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
39
+ GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
40
+
50
41
  #ifdef __cplusplus
51
42
  }
52
43
  #endif
@@ -1430,6 +1430,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1430
1430
  struct ggml_tensor * dst = gf->nodes[i];
1431
1431
  GGML_ASSERT(dst->data != nullptr);
1432
1432
 
1433
+ if (ggml_is_empty(dst)) {
1434
+ continue;
1435
+ }
1436
+
1433
1437
  switch (dst->op) {
1434
1438
  case GGML_OP_NONE:
1435
1439
  case GGML_OP_RESHAPE:
@@ -1951,6 +1955,7 @@ static struct ggml_backend_i kompute_backend_i = {
1951
1955
  /* .graph_plan_compute = */ NULL,
1952
1956
  /* .graph_compute = */ ggml_backend_kompute_graph_compute,
1953
1957
  /* .supports_op = */ ggml_backend_kompute_supports_op,
1958
+ /* .offload_op = */ NULL,
1954
1959
  /* .event_new = */ NULL,
1955
1960
  /* .event_free = */ NULL,
1956
1961
  /* .event_record = */ NULL,
@@ -64,6 +64,7 @@ enum ggml_metal_kernel_type {
64
64
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
65
65
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
66
66
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
67
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
67
68
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
68
69
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
69
70
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
91
92
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
92
93
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
93
94
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
95
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
94
96
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
95
97
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
96
98
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@@ -114,6 +116,7 @@ enum ggml_metal_kernel_type {
114
116
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
115
117
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
116
118
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
119
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
117
120
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
118
121
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
119
122
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@@ -134,6 +137,7 @@ enum ggml_metal_kernel_type {
134
137
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
135
138
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
136
139
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
140
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
137
141
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
138
142
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
139
143
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@@ -154,6 +158,7 @@ enum ggml_metal_kernel_type {
154
158
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
155
159
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
156
160
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
161
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
157
162
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
158
163
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
159
164
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
@@ -173,8 +178,9 @@ enum ggml_metal_kernel_type {
173
178
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
174
179
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
175
180
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
176
- //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
177
- //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
181
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
182
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
183
+ GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
178
184
  GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
179
185
  GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
180
186
  GGML_METAL_KERNEL_TYPE_CONCAT,
@@ -489,6 +495,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
489
495
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
490
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
491
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
498
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
492
499
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
493
500
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
494
501
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
@@ -516,6 +523,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
516
523
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
517
524
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
518
525
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
526
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
519
527
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
520
528
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
521
529
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
@@ -539,6 +547,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
539
547
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
540
548
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
541
549
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
550
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
542
551
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
543
552
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
544
553
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
@@ -559,6 +568,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
559
568
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
560
569
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
561
570
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
571
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
562
572
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
563
573
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
564
574
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
@@ -579,6 +589,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
579
589
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
580
590
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
581
591
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
592
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
582
593
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
583
594
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
584
595
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
@@ -598,8 +609,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
598
609
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
599
610
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
600
611
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
601
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
602
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
612
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
613
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
614
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
603
615
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
604
616
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
605
617
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
@@ -739,6 +751,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
739
751
  case GGML_TYPE_Q8_0:
740
752
  case GGML_TYPE_Q4_0:
741
753
  case GGML_TYPE_Q4_1:
754
+ case GGML_TYPE_Q5_0:
755
+ case GGML_TYPE_Q5_1:
756
+ case GGML_TYPE_IQ4_NL:
742
757
  return true;
743
758
  default:
744
759
  return false;
@@ -832,6 +847,10 @@ static enum ggml_status ggml_metal_graph_compute(
832
847
  struct ggml_tensor * src2 = gf->nodes[i]->src[2];
833
848
  struct ggml_tensor * dst = gf->nodes[i];
834
849
 
850
+ if (ggml_is_empty(dst)) {
851
+ continue;
852
+ }
853
+
835
854
  switch (dst->op) {
836
855
  case GGML_OP_NONE:
837
856
  case GGML_OP_RESHAPE:
@@ -1387,6 +1406,14 @@ static enum ggml_status ggml_metal_graph_compute(
1387
1406
  (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1388
1407
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1389
1408
 
1409
+ // some Metal matrix data types require aligned pointers
1410
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1411
+ switch (src0->type) {
1412
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1413
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1414
+ default: break;
1415
+ }
1416
+
1390
1417
  id<MTLComputePipelineState> pipeline = nil;
1391
1418
 
1392
1419
  switch (src0->type) {
@@ -1408,6 +1435,7 @@ static enum ggml_status ggml_metal_graph_compute(
1408
1435
  case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1409
1436
  case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1410
1437
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1438
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
1411
1439
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1412
1440
  case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1413
1441
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@@ -1562,6 +1590,12 @@ static enum ggml_status ggml_metal_graph_compute(
1562
1590
  nth1 = 16;
1563
1591
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
1564
1592
  } break;
1593
+ case GGML_TYPE_IQ1_M:
1594
+ {
1595
+ nth0 = 4;
1596
+ nth1 = 16;
1597
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
1598
+ } break;
1565
1599
  case GGML_TYPE_IQ4_NL:
1566
1600
  {
1567
1601
  nth0 = 4;
@@ -1606,9 +1640,9 @@ static enum ggml_status ggml_metal_graph_compute(
1606
1640
  [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1607
1641
  [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1608
1642
 
1609
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1610
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1611
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
1643
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1644
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1645
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
1612
1646
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1613
1647
  }
1614
1648
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1651,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
1651
1685
  {
1652
1686
  //GGML_ASSERT(ne00 == ne10);
1653
1687
  //GGML_ASSERT(ne03 == ne13);
1654
-
1655
- GGML_ASSERT(src0t == GGML_TYPE_I32);
1656
-
1657
- const int n_as = ((int32_t *) dst->op_params)[1];
1658
-
1659
- // TODO: make this more general
1660
- GGML_ASSERT(n_as <= 8);
1688
+ const int n_as = src0->ne[2];
1661
1689
 
1662
1690
  // max size of the src1ids array in the kernel shared buffer
1663
1691
  GGML_ASSERT(ne11 <= 4096);
1664
1692
 
1665
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
1666
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
1667
- const int64_t ne22 = src2 ? src2->ne[2] : 0;
1668
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
1693
+ // src2 = ids
1694
+ const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
1695
+ const int64_t ne21 = src2->ne[1];
1696
+ const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1697
+ const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
1698
+
1699
+ const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
1700
+ const uint64_t nb21 = src2->nb[1];
1701
+ const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
1702
+ const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
1669
1703
 
1670
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1671
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1672
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1673
- const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
1704
+ const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1674
1705
 
1675
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
1706
+ GGML_ASSERT(src2t == GGML_TYPE_I32);
1676
1707
 
1677
- GGML_ASSERT(!ggml_is_transposed(src2));
1708
+ GGML_ASSERT(!ggml_is_transposed(src0));
1678
1709
  GGML_ASSERT(!ggml_is_transposed(src1));
1679
1710
 
1680
1711
  GGML_ASSERT(src1t == GGML_TYPE_F32);
1681
1712
 
1682
- const uint r2 = ne12/ne22;
1683
- const uint r3 = ne13/ne23;
1684
-
1685
1713
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1686
1714
  // to the matrix-vector kernel
1687
1715
  int ne11_mm_min = n_as;
@@ -1689,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
1689
1717
  const int idx = ((int32_t *) dst->op_params)[0];
1690
1718
 
1691
1719
  // batch size
1692
- GGML_ASSERT(ne01 == ne11);
1720
+ GGML_ASSERT(ne21 == ne11); // ?
1721
+ GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
1722
+ const uint r2 = 1;
1723
+ const uint r3 = 1;
1693
1724
 
1694
1725
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1695
1726
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1698,12 +1729,20 @@ static enum ggml_status ggml_metal_graph_compute(
1698
1729
  // indirect matrix multiplication
1699
1730
  // !!!
1700
1731
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1701
- ne20 % 32 == 0 && ne20 >= 64 &&
1732
+ ne00 % 32 == 0 && ne00 >= 64 &&
1702
1733
  ne11 > ne11_mm_min) {
1703
1734
 
1735
+ // some Metal matrix data types require aligned pointers
1736
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1737
+ switch (src0->type) {
1738
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1739
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1740
+ default: break;
1741
+ }
1742
+
1704
1743
  id<MTLComputePipelineState> pipeline = nil;
1705
1744
 
1706
- switch (src2->type) {
1745
+ switch (src0->type) {
1707
1746
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
1708
1747
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
1709
1748
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
@@ -1722,6 +1761,7 @@ static enum ggml_status ggml_metal_graph_compute(
1722
1761
  case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
1723
1762
  case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
1724
1763
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1764
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
1725
1765
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1726
1766
  case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
1727
1767
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@@ -1731,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
1731
1771
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1732
1772
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1733
1773
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1734
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1735
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1736
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1737
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1738
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1739
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1740
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1741
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1742
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1743
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1744
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1745
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1746
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1747
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1748
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1749
- [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1750
- // TODO: how to make this an array? read Metal docs
1751
- for (int j = 0; j < 8; ++j) {
1752
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1753
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1754
-
1755
- size_t offs_src_cur = 0;
1756
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
1757
-
1758
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1759
- }
1774
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1775
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1776
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1777
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1778
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1779
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1780
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
1781
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
1782
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
1783
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
1784
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
1785
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
1786
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
1787
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
1788
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1789
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1790
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
1760
1791
 
1761
1792
  [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
1762
1793
 
1763
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1794
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1764
1795
  } else {
1765
1796
  int nth0 = 32;
1766
1797
  int nth1 = 1;
@@ -1770,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
1770
1801
  id<MTLComputePipelineState> pipeline = nil;
1771
1802
 
1772
1803
  // use custom matrix x vector kernel
1773
- switch (src2t) {
1804
+ switch (src0t) {
1774
1805
  case GGML_TYPE_F32:
1775
1806
  {
1776
1807
  GGML_ASSERT(src1t == GGML_TYPE_F32);
@@ -1879,6 +1910,12 @@ static enum ggml_status ggml_metal_graph_compute(
1879
1910
  nth1 = 16;
1880
1911
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
1881
1912
  } break;
1913
+ case GGML_TYPE_IQ1_M:
1914
+ {
1915
+ nth0 = 4;
1916
+ nth1 = 16;
1917
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
1918
+ } break;
1882
1919
  case GGML_TYPE_IQ4_NL:
1883
1920
  {
1884
1921
  nth0 = 4;
@@ -1898,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
1898
1935
  }
1899
1936
  };
1900
1937
 
1901
- if (ggml_is_quantized(src2t)) {
1902
- GGML_ASSERT(ne20 >= nth0*nth1);
1938
+ if (ggml_is_quantized(src0t)) {
1939
+ GGML_ASSERT(ne00 >= nth0*nth1);
1903
1940
  }
1904
1941
 
1905
1942
  const int64_t _ne1 = 1; // kernels needs a reference in constant memory
@@ -1908,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
1908
1945
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1909
1946
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1910
1947
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1911
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1912
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1913
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1914
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1915
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1916
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1917
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1918
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1919
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1920
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1921
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1922
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1923
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1924
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1925
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1926
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1927
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1928
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1929
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1930
- [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1931
- // TODO: how to make this an array? read Metal docs
1932
- for (int j = 0; j < 8; ++j) {
1933
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1934
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1935
-
1936
- size_t offs_src_cur = 0;
1937
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
1938
-
1939
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1940
- }
1941
-
1942
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1943
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1944
- src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
1945
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1948
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1949
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1950
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1951
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
1952
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
1953
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
1954
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1955
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1956
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1957
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
1958
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1959
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1960
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1961
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1962
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1963
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1964
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
1965
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
1966
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
1967
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
1968
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
1969
+
1970
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1971
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1972
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
1973
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1946
1974
  }
1947
- else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
1948
- const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1975
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1976
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1949
1977
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1950
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1978
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1951
1979
  }
1952
- else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
1953
- const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1980
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1981
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1954
1982
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1955
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1983
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1956
1984
  }
1957
- else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
1985
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
1958
1986
  const int mem_size = 32*sizeof(float);
1959
1987
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1960
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1988
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1961
1989
  }
1962
- else if (src2t == GGML_TYPE_Q4_K) {
1963
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1990
+ else if (src0t == GGML_TYPE_Q4_K) {
1991
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1964
1992
  }
1965
- else if (src2t == GGML_TYPE_Q3_K) {
1993
+ else if (src0t == GGML_TYPE_Q3_K) {
1966
1994
  #ifdef GGML_QKK_64
1967
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1995
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1968
1996
  #else
1969
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1997
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1970
1998
  #endif
1971
1999
  }
1972
- else if (src2t == GGML_TYPE_Q5_K) {
1973
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2000
+ else if (src0t == GGML_TYPE_Q5_K) {
2001
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1974
2002
  }
1975
- else if (src2t == GGML_TYPE_Q6_K) {
1976
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2003
+ else if (src0t == GGML_TYPE_Q6_K) {
2004
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1977
2005
  } else {
1978
2006
  const int64_t ny = (_ne1 + nrows - 1)/nrows;
1979
- [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2007
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1980
2008
  }
1981
2009
  }
1982
2010
  } break;
@@ -2003,6 +2031,7 @@ static enum ggml_status ggml_metal_graph_compute(
2003
2031
  case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
2004
2032
  case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
2005
2033
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
2034
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
2006
2035
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
2007
2036
  case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
2008
2037
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
@@ -2382,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(
2382
2411
 
2383
2412
  enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2384
2413
 
2414
+ // bitonic sort requires the number of elements to be power of 2
2415
+ int64_t ne00_padded = 1;
2416
+ while (ne00_padded < ne00) {
2417
+ ne00_padded *= 2;
2418
+ }
2419
+
2420
+ // Metal kernels require the buffer size to be multiple of 16 bytes
2421
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
2422
+ const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
2423
+
2385
2424
  id<MTLComputePipelineState> pipeline = nil;
2386
2425
 
2387
2426
  switch (order) {
@@ -2391,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute(
2391
2430
  };
2392
2431
 
2393
2432
  [encoder setComputePipelineState:pipeline];
2394
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2395
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2396
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2433
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2434
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2435
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2436
+ [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
2437
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2397
2438
 
2398
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2439
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
2399
2440
  } break;
2400
2441
  case GGML_OP_LEAKY_RELU:
2401
2442
  {
@@ -2431,13 +2472,14 @@ static enum ggml_status ggml_metal_graph_compute(
2431
2472
  GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
2432
2473
 
2433
2474
  switch (dstt) {
2434
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2435
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2436
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
2437
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
2438
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
2439
- //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
2440
- //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
2475
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2476
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2477
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
2478
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
2479
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
2480
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
2481
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
2482
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
2441
2483
  default: GGML_ASSERT(false && "not implemented");
2442
2484
  };
2443
2485
  } break;
@@ -2837,6 +2879,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
2837
2879
  /* .graph_plan_compute = */ NULL,
2838
2880
  /* .graph_compute = */ ggml_backend_metal_graph_compute,
2839
2881
  /* .supports_op = */ ggml_backend_metal_supports_op,
2882
+ /* .offload_op = */ NULL,
2840
2883
  /* .event_new = */ NULL,
2841
2884
  /* .event_free = */ NULL,
2842
2885
  /* .event_record = */ NULL,