llama_cpp 0.14.2 → 0.14.4

This diff shows the differences between the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -17,29 +17,17 @@ extern "C" {
17
17
 
18
18
  #define GGML_CUDA_MAX_DEVICES 16
19
19
 
20
- // Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
21
- GGML_API GGML_CALL void ggml_init_cublas(void);
22
-
23
- // Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
24
- GGML_API GGML_CALL bool ggml_cublas_loaded(void);
25
-
26
- GGML_API GGML_CALL void * ggml_cuda_host_malloc(size_t size);
27
- GGML_API GGML_CALL void ggml_cuda_host_free(void * ptr);
28
-
29
- GGML_API GGML_CALL bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
30
- GGML_API GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
31
-
32
- GGML_API GGML_CALL int ggml_cuda_get_device_count(void);
33
- GGML_API GGML_CALL void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
34
-
35
20
  // backend API
36
21
  GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
37
22
 
38
23
  GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
39
24
 
25
+ // device buffer
40
26
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
27
+
41
28
  // split tensor buffer that splits matrices by rows across multiple devices
42
29
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
30
+
43
31
  // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
44
32
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
45
33
 
@@ -47,6 +35,9 @@ GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
47
35
  GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
48
36
  GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
49
37
 
38
+ GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
39
+ GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
40
+
50
41
  #ifdef __cplusplus
51
42
  }
52
43
  #endif
@@ -1430,6 +1430,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1430
1430
  struct ggml_tensor * dst = gf->nodes[i];
1431
1431
  GGML_ASSERT(dst->data != nullptr);
1432
1432
 
1433
+ if (ggml_is_empty(dst)) {
1434
+ continue;
1435
+ }
1436
+
1433
1437
  switch (dst->op) {
1434
1438
  case GGML_OP_NONE:
1435
1439
  case GGML_OP_RESHAPE:
@@ -1951,6 +1955,7 @@ static struct ggml_backend_i kompute_backend_i = {
1951
1955
  /* .graph_plan_compute = */ NULL,
1952
1956
  /* .graph_compute = */ ggml_backend_kompute_graph_compute,
1953
1957
  /* .supports_op = */ ggml_backend_kompute_supports_op,
1958
+ /* .offload_op = */ NULL,
1954
1959
  /* .event_new = */ NULL,
1955
1960
  /* .event_free = */ NULL,
1956
1961
  /* .event_record = */ NULL,
@@ -64,6 +64,7 @@ enum ggml_metal_kernel_type {
64
64
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
65
65
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
66
66
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
67
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
67
68
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
68
69
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
69
70
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
91
92
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
92
93
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
93
94
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
95
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
94
96
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
95
97
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
96
98
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@@ -114,6 +116,7 @@ enum ggml_metal_kernel_type {
114
116
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
115
117
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
116
118
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
119
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
117
120
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
118
121
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
119
122
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@@ -134,6 +137,7 @@ enum ggml_metal_kernel_type {
134
137
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
135
138
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
136
139
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
140
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
137
141
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
138
142
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
139
143
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@@ -154,6 +158,7 @@ enum ggml_metal_kernel_type {
154
158
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
155
159
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
156
160
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
161
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
157
162
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
158
163
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
159
164
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
@@ -173,8 +178,9 @@ enum ggml_metal_kernel_type {
173
178
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
174
179
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
175
180
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
176
- //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
177
- //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
181
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
182
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
183
+ GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
178
184
  GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
179
185
  GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
180
186
  GGML_METAL_KERNEL_TYPE_CONCAT,
@@ -489,6 +495,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
489
495
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
490
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
491
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
498
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
492
499
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
493
500
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
494
501
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
@@ -516,6 +523,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
516
523
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
517
524
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
518
525
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
526
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
519
527
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
520
528
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
521
529
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
@@ -539,6 +547,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
539
547
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
540
548
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
541
549
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
550
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
542
551
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
543
552
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
544
553
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
@@ -559,6 +568,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
559
568
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
560
569
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
561
570
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
571
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
562
572
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
563
573
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
564
574
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
@@ -579,6 +589,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
579
589
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
580
590
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
581
591
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
592
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
582
593
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
583
594
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
584
595
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
@@ -598,8 +609,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
598
609
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
599
610
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
600
611
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
601
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
602
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
612
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
613
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
614
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
603
615
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
604
616
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
605
617
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
@@ -739,6 +751,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
739
751
  case GGML_TYPE_Q8_0:
740
752
  case GGML_TYPE_Q4_0:
741
753
  case GGML_TYPE_Q4_1:
754
+ case GGML_TYPE_Q5_0:
755
+ case GGML_TYPE_Q5_1:
756
+ case GGML_TYPE_IQ4_NL:
742
757
  return true;
743
758
  default:
744
759
  return false;
@@ -832,6 +847,10 @@ static enum ggml_status ggml_metal_graph_compute(
832
847
  struct ggml_tensor * src2 = gf->nodes[i]->src[2];
833
848
  struct ggml_tensor * dst = gf->nodes[i];
834
849
 
850
+ if (ggml_is_empty(dst)) {
851
+ continue;
852
+ }
853
+
835
854
  switch (dst->op) {
836
855
  case GGML_OP_NONE:
837
856
  case GGML_OP_RESHAPE:
@@ -1387,6 +1406,14 @@ static enum ggml_status ggml_metal_graph_compute(
1387
1406
  (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
1388
1407
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1389
1408
 
1409
+ // some Metal matrix data types require aligned pointers
1410
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1411
+ switch (src0->type) {
1412
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1413
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1414
+ default: break;
1415
+ }
1416
+
1390
1417
  id<MTLComputePipelineState> pipeline = nil;
1391
1418
 
1392
1419
  switch (src0->type) {
@@ -1408,6 +1435,7 @@ static enum ggml_status ggml_metal_graph_compute(
1408
1435
  case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1409
1436
  case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1410
1437
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1438
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
1411
1439
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1412
1440
  case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1413
1441
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@@ -1562,6 +1590,12 @@ static enum ggml_status ggml_metal_graph_compute(
1562
1590
  nth1 = 16;
1563
1591
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
1564
1592
  } break;
1593
+ case GGML_TYPE_IQ1_M:
1594
+ {
1595
+ nth0 = 4;
1596
+ nth1 = 16;
1597
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
1598
+ } break;
1565
1599
  case GGML_TYPE_IQ4_NL:
1566
1600
  {
1567
1601
  nth0 = 4;
@@ -1606,9 +1640,9 @@ static enum ggml_status ggml_metal_graph_compute(
1606
1640
  [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1607
1641
  [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1608
1642
 
1609
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1610
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1611
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
1643
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1644
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1645
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
1612
1646
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1613
1647
  }
1614
1648
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1651,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
1651
1685
  {
1652
1686
  //GGML_ASSERT(ne00 == ne10);
1653
1687
  //GGML_ASSERT(ne03 == ne13);
1654
-
1655
- GGML_ASSERT(src0t == GGML_TYPE_I32);
1656
-
1657
- const int n_as = ((int32_t *) dst->op_params)[1];
1658
-
1659
- // TODO: make this more general
1660
- GGML_ASSERT(n_as <= 8);
1688
+ const int n_as = src0->ne[2];
1661
1689
 
1662
1690
  // max size of the src1ids array in the kernel shared buffer
1663
1691
  GGML_ASSERT(ne11 <= 4096);
1664
1692
 
1665
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
1666
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
1667
- const int64_t ne22 = src2 ? src2->ne[2] : 0;
1668
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
1693
+ // src2 = ids
1694
+ const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
1695
+ const int64_t ne21 = src2->ne[1];
1696
+ const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1697
+ const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
1698
+
1699
+ const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
1700
+ const uint64_t nb21 = src2->nb[1];
1701
+ const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
1702
+ const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
1669
1703
 
1670
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1671
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1672
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1673
- const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
1704
+ const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1674
1705
 
1675
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
1706
+ GGML_ASSERT(src2t == GGML_TYPE_I32);
1676
1707
 
1677
- GGML_ASSERT(!ggml_is_transposed(src2));
1708
+ GGML_ASSERT(!ggml_is_transposed(src0));
1678
1709
  GGML_ASSERT(!ggml_is_transposed(src1));
1679
1710
 
1680
1711
  GGML_ASSERT(src1t == GGML_TYPE_F32);
1681
1712
 
1682
- const uint r2 = ne12/ne22;
1683
- const uint r3 = ne13/ne23;
1684
-
1685
1713
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1686
1714
  // to the matrix-vector kernel
1687
1715
  int ne11_mm_min = n_as;
@@ -1689,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
1689
1717
  const int idx = ((int32_t *) dst->op_params)[0];
1690
1718
 
1691
1719
  // batch size
1692
- GGML_ASSERT(ne01 == ne11);
1720
+ GGML_ASSERT(ne21 == ne11); // ?
1721
+ GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
1722
+ const uint r2 = 1;
1723
+ const uint r3 = 1;
1693
1724
 
1694
1725
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1695
1726
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1698,12 +1729,20 @@ static enum ggml_status ggml_metal_graph_compute(
1698
1729
  // indirect matrix multiplication
1699
1730
  // !!!
1700
1731
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1701
- ne20 % 32 == 0 && ne20 >= 64 &&
1732
+ ne00 % 32 == 0 && ne00 >= 64 &&
1702
1733
  ne11 > ne11_mm_min) {
1703
1734
 
1735
+ // some Metal matrix data types require aligned pointers
1736
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1737
+ switch (src0->type) {
1738
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1739
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1740
+ default: break;
1741
+ }
1742
+
1704
1743
  id<MTLComputePipelineState> pipeline = nil;
1705
1744
 
1706
- switch (src2->type) {
1745
+ switch (src0->type) {
1707
1746
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
1708
1747
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
1709
1748
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
@@ -1722,6 +1761,7 @@ static enum ggml_status ggml_metal_graph_compute(
1722
1761
  case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
1723
1762
  case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
1724
1763
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1764
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
1725
1765
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1726
1766
  case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
1727
1767
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@@ -1731,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
1731
1771
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1732
1772
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1733
1773
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1734
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1735
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1736
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1737
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1738
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1739
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1740
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1741
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1742
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1743
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1744
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1745
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1746
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1747
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1748
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1749
- [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1750
- // TODO: how to make this an array? read Metal docs
1751
- for (int j = 0; j < 8; ++j) {
1752
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1753
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1754
-
1755
- size_t offs_src_cur = 0;
1756
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
1757
-
1758
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1759
- }
1774
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1775
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1776
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1777
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1778
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1779
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1780
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
1781
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
1782
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
1783
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
1784
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
1785
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
1786
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
1787
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
1788
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1789
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1790
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
1760
1791
 
1761
1792
  [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
1762
1793
 
1763
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1794
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1764
1795
  } else {
1765
1796
  int nth0 = 32;
1766
1797
  int nth1 = 1;
@@ -1770,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
1770
1801
  id<MTLComputePipelineState> pipeline = nil;
1771
1802
 
1772
1803
  // use custom matrix x vector kernel
1773
- switch (src2t) {
1804
+ switch (src0t) {
1774
1805
  case GGML_TYPE_F32:
1775
1806
  {
1776
1807
  GGML_ASSERT(src1t == GGML_TYPE_F32);
@@ -1879,6 +1910,12 @@ static enum ggml_status ggml_metal_graph_compute(
1879
1910
  nth1 = 16;
1880
1911
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
1881
1912
  } break;
1913
+ case GGML_TYPE_IQ1_M:
1914
+ {
1915
+ nth0 = 4;
1916
+ nth1 = 16;
1917
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
1918
+ } break;
1882
1919
  case GGML_TYPE_IQ4_NL:
1883
1920
  {
1884
1921
  nth0 = 4;
@@ -1898,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
1898
1935
  }
1899
1936
  };
1900
1937
 
1901
- if (ggml_is_quantized(src2t)) {
1902
- GGML_ASSERT(ne20 >= nth0*nth1);
1938
+ if (ggml_is_quantized(src0t)) {
1939
+ GGML_ASSERT(ne00 >= nth0*nth1);
1903
1940
  }
1904
1941
 
1905
1942
  const int64_t _ne1 = 1; // kernels needs a reference in constant memory
@@ -1908,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
1908
1945
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1909
1946
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1910
1947
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1911
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1912
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1913
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1914
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1915
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1916
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1917
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1918
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1919
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1920
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1921
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1922
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1923
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1924
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1925
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1926
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1927
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1928
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1929
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1930
- [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1931
- // TODO: how to make this an array? read Metal docs
1932
- for (int j = 0; j < 8; ++j) {
1933
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1934
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1935
-
1936
- size_t offs_src_cur = 0;
1937
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
1938
-
1939
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1940
- }
1941
-
1942
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1943
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1944
- src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
1945
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1948
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1949
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1950
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1951
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
1952
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
1953
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
1954
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1955
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1956
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1957
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
1958
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1959
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1960
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1961
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1962
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1963
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1964
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
1965
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
1966
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
1967
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
1968
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
1969
+
1970
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1971
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1972
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
1973
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1946
1974
  }
1947
- else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
1948
- const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1975
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1976
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1949
1977
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1950
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1978
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1951
1979
  }
1952
- else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
1953
- const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1980
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1981
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1954
1982
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1955
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1983
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1956
1984
  }
1957
- else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
1985
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
1958
1986
  const int mem_size = 32*sizeof(float);
1959
1987
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1960
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1988
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1961
1989
  }
1962
- else if (src2t == GGML_TYPE_Q4_K) {
1963
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1990
+ else if (src0t == GGML_TYPE_Q4_K) {
1991
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1964
1992
  }
1965
- else if (src2t == GGML_TYPE_Q3_K) {
1993
+ else if (src0t == GGML_TYPE_Q3_K) {
1966
1994
  #ifdef GGML_QKK_64
1967
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1995
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1968
1996
  #else
1969
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1997
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1970
1998
  #endif
1971
1999
  }
1972
- else if (src2t == GGML_TYPE_Q5_K) {
1973
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2000
+ else if (src0t == GGML_TYPE_Q5_K) {
2001
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1974
2002
  }
1975
- else if (src2t == GGML_TYPE_Q6_K) {
1976
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2003
+ else if (src0t == GGML_TYPE_Q6_K) {
2004
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1977
2005
  } else {
1978
2006
  const int64_t ny = (_ne1 + nrows - 1)/nrows;
1979
- [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2007
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1980
2008
  }
1981
2009
  }
1982
2010
  } break;
@@ -2003,6 +2031,7 @@ static enum ggml_status ggml_metal_graph_compute(
2003
2031
  case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
2004
2032
  case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
2005
2033
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
2034
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
2006
2035
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
2007
2036
  case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
2008
2037
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
@@ -2382,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(
2382
2411
 
2383
2412
  enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2384
2413
 
2414
+ // bitonic sort requires the number of elements to be power of 2
2415
+ int64_t ne00_padded = 1;
2416
+ while (ne00_padded < ne00) {
2417
+ ne00_padded *= 2;
2418
+ }
2419
+
2420
+ // Metal kernels require the buffer size to be multiple of 16 bytes
2421
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
2422
+ const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
2423
+
2385
2424
  id<MTLComputePipelineState> pipeline = nil;
2386
2425
 
2387
2426
  switch (order) {
@@ -2391,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute(
2391
2430
  };
2392
2431
 
2393
2432
  [encoder setComputePipelineState:pipeline];
2394
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2395
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2396
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2433
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2434
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2435
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2436
+ [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
2437
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2397
2438
 
2398
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2439
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
2399
2440
  } break;
2400
2441
  case GGML_OP_LEAKY_RELU:
2401
2442
  {
@@ -2431,13 +2472,14 @@ static enum ggml_status ggml_metal_graph_compute(
2431
2472
  GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
2432
2473
 
2433
2474
  switch (dstt) {
2434
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2435
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2436
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
2437
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
2438
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
2439
- //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
2440
- //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
2475
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2476
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2477
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
2478
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
2479
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
2480
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
2481
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
2482
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
2441
2483
  default: GGML_ASSERT(false && "not implemented");
2442
2484
  };
2443
2485
  } break;
@@ -2837,6 +2879,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
2837
2879
  /* .graph_plan_compute = */ NULL,
2838
2880
  /* .graph_compute = */ ggml_backend_metal_graph_compute,
2839
2881
  /* .supports_op = */ ggml_backend_metal_supports_op,
2882
+ /* .offload_op = */ NULL,
2840
2883
  /* .event_new = */ NULL,
2841
2884
  /* .event_free = */ NULL,
2842
2885
  /* .event_record = */ NULL,