llama_cpp 0.3.7 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -5,7 +5,6 @@
5
5
  #import <Foundation/Foundation.h>
6
6
 
7
7
  #import <Metal/Metal.h>
8
- #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
9
8
 
10
9
  #undef MIN
11
10
  #undef MAX
@@ -64,6 +63,7 @@ struct ggml_metal_context {
64
63
  GGML_METAL_DECL_KERNEL(get_rows_f16);
65
64
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
66
65
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
66
+ GGML_METAL_DECL_KERNEL(get_rows_q8_0);
67
67
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
68
68
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
69
69
  GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -74,11 +74,21 @@ struct ggml_metal_context {
74
74
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
75
75
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
76
76
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
77
+ GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
77
78
  GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
78
79
  GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
79
80
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
80
81
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
81
82
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
83
+ GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
84
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
85
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
86
+ GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
87
+ GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
88
+ GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
89
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
90
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
91
+ GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
82
92
  GGML_METAL_DECL_KERNEL(rope);
83
93
  GGML_METAL_DECL_KERNEL(alibi_f32);
84
94
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -110,13 +120,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
110
120
  ctx->n_buffers = 0;
111
121
  ctx->concur_list_len = 0;
112
122
 
113
- // determine if we can use MPS
114
- if (MPSSupportsMTLDevice(ctx->device)) {
115
- fprintf(stderr, "%s: using MPS\n", __func__);
116
- } else {
117
- fprintf(stderr, "%s: not using MPS\n", __func__);
118
- GGML_ASSERT(false && "MPS not supported");
119
- }
120
123
 
121
124
  #if 0
122
125
  // compile from source string and show compile log
@@ -126,7 +129,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
126
129
  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
127
130
  if (error) {
128
131
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
129
- exit(1);
132
+ return NULL;
130
133
  }
131
134
  }
132
135
  #else
@@ -144,7 +147,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
144
147
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
145
148
  if (error) {
146
149
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
147
- exit(1);
150
+ return NULL;
148
151
  }
149
152
 
150
153
  #ifdef GGML_QKK_64
@@ -156,17 +159,24 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
156
159
  #endif
157
160
  if (error) {
158
161
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
159
- exit(1);
162
+ return NULL;
160
163
  }
161
164
  }
162
165
  #endif
163
166
 
164
167
  // load kernels
165
168
  {
169
+ NSError * error = nil;
166
170
  #define GGML_METAL_ADD_KERNEL(name) \
167
171
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
168
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
169
- fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
172
+ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
173
+ fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
174
+ (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
175
+ (int) ctx->pipeline_##name.threadExecutionWidth); \
176
+ if (error) { \
177
+ fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
178
+ return NULL; \
179
+ }
170
180
 
171
181
  GGML_METAL_ADD_KERNEL(add);
172
182
  GGML_METAL_ADD_KERNEL(add_row);
@@ -181,6 +191,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
181
191
  GGML_METAL_ADD_KERNEL(get_rows_f16);
182
192
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
183
193
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
194
+ GGML_METAL_ADD_KERNEL(get_rows_q8_0);
184
195
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
185
196
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
186
197
  GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -191,11 +202,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
191
202
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
192
203
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
193
204
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
205
+ GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
194
206
  GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
195
207
  GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
196
208
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
197
209
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
198
210
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
211
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
212
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
213
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
214
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
215
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
216
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
217
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
218
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
219
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
199
220
  GGML_METAL_ADD_KERNEL(rope);
200
221
  GGML_METAL_ADD_KERNEL(alibi_f32);
201
222
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -205,12 +226,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
205
226
  #undef GGML_METAL_ADD_KERNEL
206
227
  }
207
228
 
208
- fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
209
- fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
229
+ fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
230
+ fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
210
231
  if (ctx->device.maxTransferRate != 0) {
211
- fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
232
+ fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
212
233
  } else {
213
- fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
234
+ fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
214
235
  }
215
236
 
216
237
  return ctx;
@@ -224,15 +245,31 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
224
245
  free(ctx);
225
246
  }
226
247
 
248
+ void * ggml_metal_host_malloc(size_t n) {
249
+ void * data = NULL;
250
+ const int result = posix_memalign((void **) &data, getpagesize(), n);
251
+ if (result != 0) {
252
+ fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
253
+ return NULL;
254
+ }
255
+
256
+ return data;
257
+ }
258
+
259
+ void ggml_metal_host_free(void * data) {
260
+ free(data);
261
+ }
262
+
227
263
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
228
264
  ctx->n_cb = n_cb;
229
265
  }
230
266
 
231
- bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
232
- if (ctx->concur_list_len) {
233
- return true;
234
- }
235
- return false;
267
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
268
+ return ctx->concur_list_len;
269
+ }
270
+
271
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
272
+ return ctx->concur_list;
236
273
  }
237
274
 
238
275
  // finds the Metal buffer that contains the tensor data on the GPU device
@@ -375,7 +412,7 @@ void ggml_metal_get_tensor(
375
412
 
376
413
  void ggml_metal_graph_find_concurrency(
377
414
  struct ggml_metal_context * ctx,
378
- struct ggml_cgraph * gf) {
415
+ struct ggml_cgraph * gf, bool check_mem) {
379
416
  int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
380
417
  int nodes_unused[GGML_MAX_CONCUR];
381
418
 
@@ -422,7 +459,7 @@ void ggml_metal_graph_find_concurrency(
422
459
  }
423
460
  }
424
461
  }
425
- if (exe_flag) {
462
+ if (exe_flag && check_mem) {
426
463
  // check if nodes[i]'s data will be overwritten by a node before nodes[i].
427
464
  // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
428
465
  int64_t data_start = (int64_t) gf->nodes[i]->data;
@@ -506,19 +543,15 @@ void ggml_metal_graph_compute(
506
543
 
507
544
  id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
508
545
 
509
- id<MTLComputeCommandEncoder> encoder = nil;
546
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
510
547
 
511
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
512
- const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
548
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
549
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
513
550
 
514
551
  for (int ind = node_start; ind < node_end; ++ind) {
515
552
  const int i = has_concur ? ctx->concur_list[ind] : ind;
516
553
 
517
554
  if (i == -1) {
518
- if (encoder == nil) {
519
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
520
- continue;
521
- }
522
555
  [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
523
556
  continue;
524
557
  }
@@ -592,10 +625,6 @@ void ggml_metal_graph_compute(
592
625
  } break;
593
626
  case GGML_OP_ADD:
594
627
  {
595
- if (encoder == nil) {
596
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
597
- }
598
-
599
628
  if (ggml_nelements(src1) == ne10) {
600
629
  // src1 is a row
601
630
  [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -613,10 +642,6 @@ void ggml_metal_graph_compute(
613
642
  } break;
614
643
  case GGML_OP_MUL:
615
644
  {
616
- if (encoder == nil) {
617
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
618
- }
619
-
620
645
  if (ggml_nelements(src1) == ne10) {
621
646
  // src1 is a row
622
647
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -634,10 +659,6 @@ void ggml_metal_graph_compute(
634
659
  } break;
635
660
  case GGML_OP_SCALE:
636
661
  {
637
- if (encoder == nil) {
638
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
639
- }
640
-
641
662
  const float scale = *(const float *) src1->data;
642
663
 
643
664
  [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -653,10 +674,6 @@ void ggml_metal_graph_compute(
653
674
  switch (ggml_get_unary_op(gf->nodes[i])) {
654
675
  case GGML_UNARY_OP_SILU:
655
676
  {
656
- if (encoder == nil) {
657
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
658
- }
659
-
660
677
  [encoder setComputePipelineState:ctx->pipeline_silu];
661
678
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
662
679
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -667,10 +684,6 @@ void ggml_metal_graph_compute(
667
684
  } break;
668
685
  case GGML_UNARY_OP_RELU:
669
686
  {
670
- if (encoder == nil) {
671
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
672
- }
673
-
674
687
  [encoder setComputePipelineState:ctx->pipeline_relu];
675
688
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
676
689
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -681,10 +694,6 @@ void ggml_metal_graph_compute(
681
694
  } break;
682
695
  case GGML_UNARY_OP_GELU:
683
696
  {
684
- if (encoder == nil) {
685
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
686
- }
687
-
688
697
  [encoder setComputePipelineState:ctx->pipeline_gelu];
689
698
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
690
699
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -701,10 +710,6 @@ void ggml_metal_graph_compute(
701
710
  } break;
702
711
  case GGML_OP_SOFT_MAX:
703
712
  {
704
- if (encoder == nil) {
705
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
706
- }
707
-
708
713
  const int nth = 32;
709
714
 
710
715
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -719,10 +724,6 @@ void ggml_metal_graph_compute(
719
724
  } break;
720
725
  case GGML_OP_DIAG_MASK_INF:
721
726
  {
722
- if (encoder == nil) {
723
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
724
- }
725
-
726
727
  const int n_past = ((int32_t *)(dst->op_params))[0];
727
728
 
728
729
  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
@@ -740,53 +741,43 @@ void ggml_metal_graph_compute(
740
741
 
741
742
  GGML_ASSERT(ne00 == ne10);
742
743
  // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
744
+ uint gqa = ne12/ne02;
743
745
  GGML_ASSERT(ne03 == ne13);
744
746
 
747
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
748
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
745
749
  if (ggml_is_contiguous(src0) &&
746
750
  ggml_is_contiguous(src1) &&
747
- (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
748
-
749
- if (encoder != nil) {
750
- [encoder endEncoding];
751
- encoder = nil;
752
- }
753
-
754
- MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
755
- MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
756
-
757
- // for F32 x F32 we use MPS
758
- MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
759
- matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
760
-
761
- MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
762
- matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
763
-
764
- MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
765
- matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
766
-
767
- MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
768
- initWithDevice:ctx->device transposeLeft:false transposeRight:true
769
- resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
770
-
771
- // we need to do ne12 multiplications
772
- // TODO: is there a way to do this in parallel - currently very slow ..
773
- // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
774
- for (int64_t i02 = 0; i02 < ne12; ++i02) {
775
- size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
776
- size_t offs_src1_cur = offs_src1 + i02*nb12;
777
- size_t offs_dst_cur = offs_dst + i02*nb2;
778
-
779
- MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
780
- MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
781
- MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
782
-
783
- [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
751
+ src1t == GGML_TYPE_F32 &&
752
+ [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
753
+ ne00%32 == 0 &&
754
+ ne11 > 1) {
755
+ switch (src0->type) {
756
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
757
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
758
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
759
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
760
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
761
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
762
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
763
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
764
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
765
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
784
766
  }
767
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
768
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
769
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
770
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
771
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
772
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
773
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
774
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
775
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
776
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
777
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
778
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
779
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
785
780
  } else {
786
- if (encoder == nil) {
787
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
788
- }
789
-
790
781
  int nth0 = 32;
791
782
  int nth1 = 1;
792
783
 
@@ -816,6 +807,15 @@ void ggml_metal_graph_compute(
816
807
  nth1 = 8;
817
808
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
818
809
  } break;
810
+ case GGML_TYPE_Q8_0:
811
+ {
812
+ GGML_ASSERT(ne02 == 1);
813
+ GGML_ASSERT(ne12 == 1);
814
+
815
+ nth0 = 8;
816
+ nth1 = 8;
817
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
818
+ } break;
819
819
  case GGML_TYPE_Q2_K:
820
820
  {
821
821
  GGML_ASSERT(ne02 == 1);
@@ -885,23 +885,24 @@ void ggml_metal_graph_compute(
885
885
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
886
886
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
887
887
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
888
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
888
889
 
889
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
890
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
890
891
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
891
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
892
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
892
893
  }
893
894
  else if (src0t == GGML_TYPE_Q3_K) {
894
895
  #ifdef GGML_QKK_64
895
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
896
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
896
897
  #else
897
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
898
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
898
899
  #endif
899
900
  }
900
901
  else if (src0t == GGML_TYPE_Q5_K) {
901
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
902
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
902
903
  }
903
904
  else if (src0t == GGML_TYPE_Q6_K) {
904
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
905
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
905
906
  } else {
906
907
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
907
908
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -910,14 +911,11 @@ void ggml_metal_graph_compute(
910
911
  } break;
911
912
  case GGML_OP_GET_ROWS:
912
913
  {
913
- if (encoder == nil) {
914
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
915
- }
916
-
917
914
  switch (src0->type) {
918
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
915
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
919
916
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
920
917
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
918
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
921
919
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
922
920
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
923
921
  case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -939,10 +937,6 @@ void ggml_metal_graph_compute(
939
937
  } break;
940
938
  case GGML_OP_RMS_NORM:
941
939
  {
942
- if (encoder == nil) {
943
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
944
- }
945
-
946
940
  float eps;
947
941
  memcpy(&eps, dst->op_params, sizeof(float));
948
942
 
@@ -962,20 +956,17 @@ void ggml_metal_graph_compute(
962
956
  } break;
963
957
  case GGML_OP_NORM:
964
958
  {
965
- if (encoder == nil) {
966
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
967
- }
968
-
969
- const float eps = 1e-5f;
959
+ float eps;
960
+ memcpy(&eps, dst->op_params, sizeof(float));
970
961
 
971
962
  const int nth = 256;
972
963
 
973
964
  [encoder setComputePipelineState:ctx->pipeline_norm];
974
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
975
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
976
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
977
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
978
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
965
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
966
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
967
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
968
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
969
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
979
970
  [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
980
971
 
981
972
  const int64_t nrows = ggml_nrows(src0);
@@ -984,10 +975,6 @@ void ggml_metal_graph_compute(
984
975
  } break;
985
976
  case GGML_OP_ALIBI:
986
977
  {
987
- if (encoder == nil) {
988
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
989
- }
990
-
991
978
  GGML_ASSERT((src0t == GGML_TYPE_F32));
992
979
 
993
980
  const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
@@ -1022,15 +1009,13 @@ void ggml_metal_graph_compute(
1022
1009
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1023
1010
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1024
1011
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1012
+
1025
1013
  const int nth = 32;
1014
+
1026
1015
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1027
1016
  } break;
1028
1017
  case GGML_OP_ROPE:
1029
1018
  {
1030
- if (encoder == nil) {
1031
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
1032
- }
1033
-
1034
1019
  const int n_past = ((int32_t *) dst->op_params)[0];
1035
1020
  const int n_dims = ((int32_t *) dst->op_params)[1];
1036
1021
  const int mode = ((int32_t *) dst->op_params)[2];
@@ -1041,8 +1026,8 @@ void ggml_metal_graph_compute(
1041
1026
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
1042
1027
 
1043
1028
  [encoder setComputePipelineState:ctx->pipeline_rope];
1044
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1045
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1029
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1030
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1046
1031
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1047
1032
  [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1048
1033
  [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@@ -1071,10 +1056,6 @@ void ggml_metal_graph_compute(
1071
1056
  case GGML_OP_CPY:
1072
1057
  case GGML_OP_CONT:
1073
1058
  {
1074
- if (encoder == nil) {
1075
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
1076
- }
1077
-
1078
1059
  const int nth = 32;
1079
1060
 
1080
1061
  switch (src0t) {
@@ -1097,24 +1078,24 @@ void ggml_metal_graph_compute(
1097
1078
  default: GGML_ASSERT(false && "not implemented");
1098
1079
  }
1099
1080
 
1100
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1101
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1102
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1103
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1104
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1105
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1106
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1107
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1108
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1109
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1110
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1111
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1112
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1113
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1114
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1115
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1116
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1117
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1081
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1082
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1083
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1084
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1085
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1086
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1087
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1088
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1089
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1090
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1091
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1092
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1093
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1094
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1095
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1096
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1097
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1098
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1118
1099
 
1119
1100
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1120
1101
  } break;