llama_cpp 0.3.7 → 0.3.8

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -8,29 +8,25 @@ extern "C" {
 
  #define GGML_CUDA_MAX_DEVICES 16
 
- void ggml_init_cublas(void);
- void ggml_cuda_set_tensor_split(const float * tensor_split);
-
- void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
- void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
-
- // TODO: export these with GGML_API
- void * ggml_cuda_host_malloc(size_t size);
- void ggml_cuda_host_free(void * ptr);
-
- void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
-
- void ggml_cuda_free_data(struct ggml_tensor * tensor);
- void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
- void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
- void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
- void ggml_cuda_set_main_device(int main_device);
- void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
- void ggml_cuda_set_scratch_size(size_t scratch_size);
- void ggml_cuda_free_scratch(void);
- bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+ GGML_API void ggml_init_cublas(void);
+ GGML_API void * ggml_cuda_host_malloc(size_t size);
+ GGML_API void ggml_cuda_host_free(void * ptr);
+
+ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+ GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
+ GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_set_main_device(int main_device);
+ GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
+ GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
+ GGML_API void ggml_cuda_free_scratch(void);
+ GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+
+ GGML_API int ggml_cuda_get_device_count(void);
+ GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 
  #ifdef __cplusplus
  }
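The CUDA declarations are now exported with GGML_API, and two device-query helpers are added. Below is a minimal caller-side sketch using only the two new functions; the prototypes are copied from the header above, and linking against a CUDA-enabled build of the library is assumed.

    /* Hedged sketch: enumerate CUDA devices via the helpers added in this release. */
    #include <stdio.h>
    #include <stddef.h>

    /* prototypes as declared in the header above */
    int  ggml_cuda_get_device_count(void);
    void ggml_cuda_get_device_description(int device, char * description, size_t description_size);

    int main(void) {
        const int n_devices = ggml_cuda_get_device_count();
        for (int i = 0; i < n_devices; ++i) {
            char desc[128];
            ggml_cuda_get_device_description(i, desc, sizeof(desc));
            printf("CUDA device %d: %s\n", i, desc);
        }
        return 0;
    }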
@@ -63,10 +63,13 @@ void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 
  // try to find operations that can be run concurrently in the graph
  // you should run it again if the topology of your graph changes
- void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+ void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
 
- // if the graph has been optimized for concurrently dispatch
- bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+ // if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
+ // output the concur_list for ggml_alloc
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
 
  // same as ggml_graph_compute but uses Metal
  // creates gf->n_threads command buffers in parallel
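This is an API change for the Metal backend: ggml_metal_graph_find_concurrency gains a check_mem flag, ggml_metal_if_optimized now returns the length of the concurrency list instead of a bool, and ggml_metal_get_concur_list exposes the list itself. A hedged caller-side sketch follows; ctx and gf are assumed to come from the usual ggml-metal setup, and only the three calls are taken from the header above.

    /* Hedged sketch of the revised concurrency API; not a complete program. */
    static void schedule_with_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf) {
        /* re-run whenever the graph topology changes; check_mem = true also
           rejects reorderings that would overwrite a node's input memory */
        ggml_metal_graph_find_concurrency(ctx, gf, true);

        /* previously returned bool; now returns the concur_list length (0 = not optimized) */
        const int n_concur = ggml_metal_if_optimized(ctx);
        if (n_concur > 0) {
            const int * concur = ggml_metal_get_concur_list(ctx); /* intended for ggml_alloc */
            (void) concur;
        }
    }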
@@ -5,7 +5,6 @@
  #import <Foundation/Foundation.h>
 
  #import <Metal/Metal.h>
- #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
  #undef MIN
  #undef MAX
@@ -79,6 +78,14 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
  GGML_METAL_DECL_KERNEL(rope);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -110,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->n_buffers = 0;
  ctx->concur_list_len = 0;
 
- // determine if we can use MPS
- if (MPSSupportsMTLDevice(ctx->device)) {
- fprintf(stderr, "%s: using MPS\n", __func__);
- } else {
- fprintf(stderr, "%s: not using MPS\n", __func__);
- GGML_ASSERT(false && "MPS not supported");
- }
 
  #if 0
  // compile from source string and show compile log
@@ -126,7 +126,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
  if (error) {
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
- exit(1);
+ return NULL;
  }
  }
  #else
@@ -144,7 +144,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
  if (error) {
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
- exit(1);
+ return NULL;
  }
 
  #ifdef GGML_QKK_64
@@ -156,17 +156,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #endif
  if (error) {
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
- exit(1);
+ return NULL;
  }
  }
  #endif
 
  // load kernels
  {
+ NSError * error = nil;
  #define GGML_METAL_ADD_KERNEL(name) \
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
- fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
+ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
+ fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+ if (error) { \
+ fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ return NULL; \
+ }
 
  GGML_METAL_ADD_KERNEL(add);
  GGML_METAL_ADD_KERNEL(add_row);
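With these changes the MetalPerformanceShaders dependency is dropped, and ggml_metal_init no longer aborts the process on failure: the former exit(1) calls become return NULL, and pipeline-loading errors are reported through the NSError path. A hedged caller-side sketch; the fallback logic is illustrative and not part of the library.

    /* Hedged sketch (fragment): callers can now handle Metal initialization
       failure instead of having the process terminated by exit(1). */
    struct ggml_metal_context * metal_ctx = ggml_metal_init(1 /* n_cb */);
    if (metal_ctx == NULL) {
        fprintf(stderr, "ggml_metal_init failed, falling back to CPU\n");
        /* application-specific CPU fallback goes here */
    }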
@@ -196,6 +201,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
  GGML_METAL_ADD_KERNEL(rope);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -228,11 +241,12 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
  ctx->n_cb = n_cb;
  }
 
- bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
- if (ctx->concur_list_len) {
- return true;
- }
- return false;
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+ return ctx->concur_list_len;
+ }
+
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
+ return ctx->concur_list;
  }
 
  // finds the Metal buffer that contains the tensor data on the GPU device
@@ -375,7 +389,7 @@ void ggml_metal_get_tensor(
 
  void ggml_metal_graph_find_concurrency(
  struct ggml_metal_context * ctx,
- struct ggml_cgraph * gf) {
+ struct ggml_cgraph * gf, bool check_mem) {
  int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
  int nodes_unused[GGML_MAX_CONCUR];
 
@@ -422,7 +436,7 @@ void ggml_metal_graph_find_concurrency(
  }
  }
  }
- if (exe_flag) {
+ if (exe_flag && check_mem) {
  // check if nodes[i]'s data will be overwritten by a node before nodes[i].
  // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
  int64_t data_start = (int64_t) gf->nodes[i]->data;
@@ -506,7 +520,7 @@ void ggml_metal_graph_compute(
 
  id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
 
- id<MTLComputeCommandEncoder> encoder = nil;
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
  const int node_start = (cb_idx + 0) * n_nodes_per_cb;
  const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
@@ -515,10 +529,6 @@ void ggml_metal_graph_compute(
  const int i = has_concur ? ctx->concur_list[ind] : ind;
 
  if (i == -1) {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- continue;
- }
  [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
  continue;
  }
@@ -592,10 +602,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ADD:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
  [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -613,10 +619,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_MUL:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -634,10 +636,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_SCALE:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const float scale = *(const float *) src1->data;
 
  [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -653,10 +651,6 @@ void ggml_metal_graph_compute(
  switch (ggml_get_unary_op(gf->nodes[i])) {
  case GGML_UNARY_OP_SILU:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  [encoder setComputePipelineState:ctx->pipeline_silu];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -667,10 +661,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_UNARY_OP_RELU:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  [encoder setComputePipelineState:ctx->pipeline_relu];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -681,10 +671,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_UNARY_OP_GELU:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  [encoder setComputePipelineState:ctx->pipeline_gelu];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -701,10 +687,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_SOFT_MAX:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int nth = 32;
 
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -719,10 +701,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_DIAG_MASK_INF:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int n_past = ((int32_t *)(dst->op_params))[0];
 
  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
@@ -740,53 +718,43 @@ void ggml_metal_graph_compute(
 
  GGML_ASSERT(ne00 == ne10);
  // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+ uint gqa = ne12/ne02;
  GGML_ASSERT(ne03 == ne13);
 
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
  if (ggml_is_contiguous(src0) &&
  ggml_is_contiguous(src1) &&
- (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
-
- if (encoder != nil) {
- [encoder endEncoding];
- encoder = nil;
- }
-
- MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
- MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-
- // for F32 x F32 we use MPS
- MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
-
- MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
-
- MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
-
- MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
- initWithDevice:ctx->device transposeLeft:false transposeRight:true
- resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
-
- // we need to do ne12 multiplications
- // TODO: is there a way to do this in parallel - currently very slow ..
- // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
- for (int64_t i02 = 0; i02 < ne12; ++i02) {
- size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
- size_t offs_src1_cur = offs_src1 + i02*nb12;
- size_t offs_dst_cur = offs_dst + i02*nb2;
-
- MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
- MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
- MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
-
- [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
- }
- } else {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ src1t == GGML_TYPE_F32 &&
+ [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne00%32 == 0 &&
+ ne11 > 1) {
+ switch (src0->type) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
+ }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  }
-
+ else {
  int nth0 = 32;
  int nth1 = 1;
 
@@ -885,23 +853,24 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
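The dispatch geometry above is worth spelling out: the new mul_mm path launches 128-thread threadgroups over a grid of (ne11+31)/32 x (ne01+63)/64 x ne12, and the existing matrix-vector kernels extend their grid's third dimension from 1 to ne12 so each batch (and GQA group, via the gqa value bound at index 17) is covered in one dispatch. A small, purely illustrative calculation with made-up tensor sizes:

    /* Illustrative only: mirrors the threadgroup counts implied by the dispatch
       calls above; the tensor sizes are hypothetical, not from the package. */
    #include <stdio.h>

    int main(void) {
        const int ne01 = 4096; /* rows of src0 */
        const int ne11 = 512;  /* rows of src1 */
        const int ne12 = 32;   /* batched / broadcast dimension */

        const int tg_x = (ne11 + 31) / 32; /* 16 */
        const int tg_y = (ne01 + 63) / 64; /* 64 */
        printf("mul_mm grid: %d x %d x %d threadgroups, 128 threads each\n",
               tg_x, tg_y, ne12);
        return 0;
    }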
@@ -910,10 +879,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_GET_ROWS:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  switch (src0->type) {
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
@@ -939,10 +904,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_RMS_NORM:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));
 
@@ -962,10 +923,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_NORM:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const float eps = 1e-5f;
 
  const int nth = 256;
@@ -984,10 +941,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ALIBI:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  GGML_ASSERT((src0t == GGML_TYPE_F32));
 
  const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
@@ -1027,10 +980,6 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ROPE:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
@@ -1071,10 +1020,6 @@ void ggml_metal_graph_compute(
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
  const int nth = 32;
 
  switch (src0t) {