llama_cpp 0.3.7 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
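The changes in this release are concentrated in one file of the bundled GGML sources (by the identifiers involved, llama.cpp's ggml-metal.m; the viewer omits file headers, so the name is inferred). In summary: the MetalPerformanceShaders dependency is removed in favor of custom matrix-matrix (mul_mm) kernels gated on MTLGPUFamilyApple7, the Q8_0 quantization format gains Metal kernels, initialization failures now return NULL instead of exiting the process, page-aligned host-allocation helpers are added, and batched matrix-vector dispatches learn the grouped-query-attention ratio gqa = ne12/ne02.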
@@ -5,7 +5,6 @@
 #import <Foundation/Foundation.h>
 
 #import <Metal/Metal.h>
-#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
 #undef MIN
 #undef MAX
@@ -64,6 +63,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
     GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -74,11 +74,21 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
    GGML_METAL_DECL_KERNEL(rope);
    GGML_METAL_DECL_KERNEL(alibi_f32);
    GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -110,13 +120,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
-    // determine if we can use MPS
-    if (MPSSupportsMTLDevice(ctx->device)) {
-        fprintf(stderr, "%s: using MPS\n", __func__);
-    } else {
-        fprintf(stderr, "%s: not using MPS\n", __func__);
-        GGML_ASSERT(false && "MPS not supported");
-    }
 
 #if 0
     // compile from source string and show compile log
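This removal is the headline change: 0.3.7 refused to initialize at all on devices without MetalPerformanceShaders support, aborting via GGML_ASSERT. In 0.4.0 there is no MPS requirement; capability is checked per operation instead. A minimal sketch of the new gating idea (the `supportsFamily:` test appears verbatim in the mat-mat multiplication hunk further down; the branch comments are illustrative):

```objc
// Sketch: instead of a process-wide MPS requirement at init time,
// 0.4.0 picks a kernel per operation based on the GPU family.
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
    // A14/M1 or newer: use the new simdgroup-based mul_mm kernels
} else {
    // older A-chips and AMD GPUs: fall back to the mul_mat (mat-vec) kernels
}
```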
@@ -126,7 +129,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
-            exit(1);
+            return NULL;
         }
     }
 #else
@@ -144,7 +147,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
-            exit(1);
+            return NULL;
         }
 
 #ifdef GGML_QKK_64
@@ -156,17 +159,24 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #endif
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
-            exit(1);
+            return NULL;
         }
     }
 #endif
 
     // load kernels
     {
+        NSError * error = nil;
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
-        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
-        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
+        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
+        fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+                (int) ctx->pipeline_##name.threadExecutionWidth); \
+        if (error) { \
+            fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            return NULL; \
+        }
 
         GGML_METAL_ADD_KERNEL(add);
         GGML_METAL_ADD_KERNEL(add_row);
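Every fatal `exit(1)` in `ggml_metal_init` becomes `return NULL`, and pipeline-creation errors (previously discarded via `error:nil`) are now checked, so a broken Metal setup no longer kills the host process. A hypothetical caller sketch (not part of the diff):

```objc
// With 0.4.0, Metal setup failure is reported through the return value.
struct ggml_metal_context * ctx = ggml_metal_init(1 /* n_cb */);
if (ctx == NULL) {
    fprintf(stderr, "Metal backend unavailable, falling back to CPU\n");
    // ... continue on the CPU path instead of aborting ...
}
```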
@@ -181,6 +191,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         GGML_METAL_ADD_KERNEL(get_rows_q3_K);
         GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -191,11 +202,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
         GGML_METAL_ADD_KERNEL(rope);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -205,12 +226,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    fprintf(stderr, "%s: hasUnifiedMemory             = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        fprintf(stderr, "%s: maxTransferRate              = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
+        fprintf(stderr, "%s: maxTransferRate              = built-in GPU\n", __func__);
     }
 
     return ctx;
@@ -224,15 +245,31 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     free(ctx);
 }
 
+void * ggml_metal_host_malloc(size_t n) {
+    void * data = NULL;
+    const int result = posix_memalign((void **) &data, getpagesize(), n);
+    if (result != 0) {
+        fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
+        return NULL;
+    }
+
+    return data;
+}
+
+void ggml_metal_host_free(void * data) {
+    free(data);
+}
+
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
 
-bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
-    if (ctx->concur_list_len) {
-        return true;
-    }
-    return false;
+int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    return ctx->concur_list_len;
+}
+
+int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
+    return ctx->concur_list;
 }
 
 // finds the Metal buffer that contains the tensor data on the GPU device
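The new `ggml_metal_host_malloc`/`ggml_metal_host_free` pair returns page-aligned host memory. Metal's `newBufferWithBytesNoCopy:` (which this backend uses to map host buffers into the GPU address space) requires page alignment, so memory from this allocator can be shared with the device without a copy. `ggml_metal_if_optimized` also changes type: it now returns the concurrency-list length, which remains falsy (0) when no concurrency was found, and the new `ggml_metal_get_concur_list` exposes the list itself. A usage sketch; the `ggml_metal_add_buffer` call is the pre-existing API this allocator is meant to feed, shown here as an assumption:

```objc
const size_t size = 16 * 1024 * 1024;         // 16 MiB
void * data = ggml_metal_host_malloc(size);   // page-aligned via posix_memalign
if (data != NULL) {
    // ... fill with tensor data, then register it with the backend, e.g.:
    // ggml_metal_add_buffer(ctx, "data", data, size, 0);
    ggml_metal_host_free(data);
}
```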
@@ -375,7 +412,7 @@ void ggml_metal_get_tensor(
 
 void ggml_metal_graph_find_concurrency(
     struct ggml_metal_context * ctx,
-    struct ggml_cgraph * gf) {
+    struct ggml_cgraph * gf, bool check_mem) {
     int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
     int nodes_unused[GGML_MAX_CONCUR];
 
@@ -422,7 +459,7 @@ void ggml_metal_graph_find_concurrency(
                 }
             }
         }
-        if (exe_flag) {
+        if (exe_flag && check_mem) {
             // check if nodes[i]'s data will be overwritten by a node before nodes[i].
             // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
             int64_t data_start = (int64_t) gf->nodes[i]->data;
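`ggml_metal_graph_find_concurrency` gains a `check_mem` flag: when false, the memory-overlap scan that guards against one node overwriting another's buffer is skipped, which is safe only if the caller already guarantees that node buffers do not alias. Hypothetical usage, combined with the other new accessors:

```objc
// Pass check_mem = true to keep the old, conservative behavior.
ggml_metal_graph_find_concurrency(ctx, gf, true);
if (ggml_metal_if_optimized(ctx) > 0) {       // now the list length, not a bool
    int * list = ggml_metal_get_concur_list(ctx);
    // ... inspect the reordered schedule; -1 entries mark memory barriers ...
}
```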
@@ -506,19 +543,15 @@ void ggml_metal_graph_compute(
 
         id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
 
-        id<MTLComputeCommandEncoder> encoder = nil;
+        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
-        const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-        const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+        const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
+        const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
 
         for (int ind = node_start; ind < node_end; ++ind) {
             const int i = has_concur ? ctx->concur_list[ind] : ind;
 
             if (i == -1) {
-                if (encoder == nil) {
-                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    continue;
-                }
                 [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
                 continue;
             }
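Two behavioral changes here: the compute encoder is created eagerly, once per command buffer, rather than lazily inside each op (which is why the per-op `if (encoder == nil)` blocks disappear throughout the rest of the diff), and the node range is clamped with MIN so trailing command buffers cannot index past `n_nodes`. A worked example of the clamp, assuming the surrounding code's rounding of `n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb`:

```c
// Illustrative numbers: n_nodes = 10, n_cb = 8  =>  n_nodes_per_cb = 2
// cb 4 covers [8, 10); without the clamp, cb 5 would cover [10, 12) and
// read past the node array. MIN(..., n_nodes) turns it into the empty
// range [10, 10), so surplus command buffers simply encode nothing.
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
```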
@@ -592,10 +625,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_ADD:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     if (ggml_nelements(src1) == ne10) {
                         // src1 is a row
                         [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -613,10 +642,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_MUL:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     if (ggml_nelements(src1) == ne10) {
                         // src1 is a row
                         [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -634,10 +659,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_SCALE:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const float scale = *(const float *) src1->data;
 
                     [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -653,10 +674,6 @@ void ggml_metal_graph_compute(
                 switch (ggml_get_unary_op(gf->nodes[i])) {
                     case GGML_UNARY_OP_SILU:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             [encoder setComputePipelineState:ctx->pipeline_silu];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -667,10 +684,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_UNARY_OP_RELU:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             [encoder setComputePipelineState:ctx->pipeline_relu];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -681,10 +694,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_UNARY_OP_GELU:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             [encoder setComputePipelineState:ctx->pipeline_gelu];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -701,10 +710,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_SOFT_MAX:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const int nth = 32;
 
                     [encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -719,10 +724,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_DIAG_MASK_INF:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const int n_past = ((int32_t *)(dst->op_params))[0];
 
                     [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
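The eight hunks above are one mechanical cleanup repeated: with the encoder now created up front (see the `@@ -506,19` hunk), every op handler drops its lazy `if (encoder == nil)` creation block. The same deletion recurs below for GGML_OP_GET_ROWS, GGML_OP_RMS_NORM, GGML_OP_NORM, GGML_OP_ALIBI, GGML_OP_ROPE, and GGML_OP_CPY/GGML_OP_CONT.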
@@ -740,53 +741,43 @@ void ggml_metal_graph_compute(
 
                     GGML_ASSERT(ne00 == ne10);
                     // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+                    uint gqa = ne12/ne02;
                     GGML_ASSERT(ne03 == ne13);
 
+                    // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                    // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                     if (ggml_is_contiguous(src0) &&
                         ggml_is_contiguous(src1) &&
-                        (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
-
-                        if (encoder != nil) {
-                            [encoder endEncoding];
-                            encoder = nil;
-                        }
-
-                        MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-                        MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-
-                        // for F32 x F32 we use MPS
-                        MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
-
-                        MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
-
-                        MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
-
-                        MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
-                            initWithDevice:ctx->device transposeLeft:false transposeRight:true
-                                resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
-
-                        // we need to do ne12 multiplications
-                        // TODO: is there a way to do this in parallel - currently very slow ..
-                        // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
-                        for (int64_t i02 = 0; i02 < ne12; ++i02) {
-                            size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
-                            size_t offs_src1_cur = offs_src1 + i02*nb12;
-                            size_t offs_dst_cur  = offs_dst  + i02*nb2;
-
-                            MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
-                            MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
-                            MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
-
-                            [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+                        src1t == GGML_TYPE_F32 &&
+                        [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                        ne00%32 == 0 &&
+                        ne11 > 1) {
+                        switch (src0->type) {
+                            case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
+                            case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+                            case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                            case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
+                            case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+                            case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+                            case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+                            case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+                            case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+                            default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                         }
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:8];
+                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:9];
+                        [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:10];
+                        [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                     } else {
-                        if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                        }
-
                         int nth0 = 32;
                         int nth1 = 1;
 
@@ -816,6 +807,15 @@ void ggml_metal_graph_compute(
                                 nth1 = 8;
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
                             } break;
+                        case GGML_TYPE_Q8_0:
+                            {
+                                GGML_ASSERT(ne02 == 1);
+                                GGML_ASSERT(ne12 == 1);
+
+                                nth0 = 8;
+                                nth1 = 8;
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+                            } break;
                         case GGML_TYPE_Q2_K:
                             {
                                 GGML_ASSERT(ne02 == 1);
@@ -885,23 +885,24 @@ void ggml_metal_graph_compute(
                     [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                     [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
                     [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
+                    [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
-                    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
                         src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                     }
                     else if (src0t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     } else {
                         [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
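The matrix-vector path picks up the same batching: dispatches now use `ne12` threadgroups in the Z dimension instead of 1, so all planes of a batched multiplication go out in a single dispatch, and `gqa` (passed at buffer index 17) tells the kernel how many query planes share one `src0` slice. With the earlier `gqa = ne12/ne02`, a sketch of the arithmetic:

```c
// Assumed model shape: 32 query heads (ne12) over 8 KV heads (ne02).
const uint gqa = 32 / 8;   // = 4: four consecutive z-planes share one
                           // src0 slice, selected by (z / gqa) inside
                           // the kernel (the idea, not its exact code).
```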
@@ -910,14 +911,11 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_GET_ROWS:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     switch (src0->type) {
-                        case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+                        case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                         case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                        case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                         case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                         case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
                         case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -939,10 +937,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_RMS_NORM:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));
 
@@ -962,20 +956,17 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_NORM:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
-                    const float eps = 1e-5f;
+                    float eps;
+                    memcpy(&eps, dst->op_params, sizeof(float));
 
                     const int nth = 256;
 
                     [encoder setComputePipelineState:ctx->pipeline_norm];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                    [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+                    [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                    [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                    [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
                     [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                     const int64_t nrows = ggml_nrows(src0);
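`GGML_OP_NORM` stops hard-coding `eps = 1e-5f` and reads it from `dst->op_params`, matching what `GGML_OP_RMS_NORM` already does in the hunk above. The `memcpy` is deliberate: `op_params` is a raw byte array, and copying avoids the aliasing and alignment pitfalls of casting it to `float *`. The producer side presumably mirrors it (a sketch, assuming ggml's usual parameter-block pattern):

```c
// Hypothetical graph-building side: store eps into the op's parameter
// block so any backend can recover it with the matching memcpy.
float eps = 1e-5f;
memcpy(dst->op_params, &eps, sizeof(eps));
```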
@@ -984,10 +975,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_ALIBI:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     GGML_ASSERT((src0t == GGML_TYPE_F32));
 
                     const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
@@ -1022,15 +1009,13 @@ void ggml_metal_graph_compute(
                     [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
                     [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
                     [encoder setBytes:&m0   length:sizeof(   float) atIndex:18];
+
                     const int nth = 32;
+
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
             case GGML_OP_ROPE:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const int n_past = ((int32_t *) dst->op_params)[0];
                     const int n_dims = ((int32_t *) dst->op_params)[1];
                     const int mode   = ((int32_t *) dst->op_params)[2];
@@ -1041,8 +1026,8 @@ void ggml_metal_graph_compute(
                     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                     [encoder setComputePipelineState:ctx->pipeline_rope];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                     [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                     [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
                     [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@@ -1071,10 +1056,6 @@ void ggml_metal_graph_compute(
             case GGML_OP_CPY:
             case GGML_OP_CONT:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const int nth = 32;
 
                     switch (src0t) {
@@ -1097,24 +1078,24 @@ void ggml_metal_graph_compute(
                         default: GGML_ASSERT(false && "not implemented");
                     }
 
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                    [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                    [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                    [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                    [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                    [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;