llama_cpp 0.3.7 → 0.3.8

Sign up to get free protection for your applications and to get access to all the features.
@@ -8,29 +8,25 @@ extern "C" {
8
8
 
9
9
  #define GGML_CUDA_MAX_DEVICES 16
10
10
 
11
- void ggml_init_cublas(void);
12
- void ggml_cuda_set_tensor_split(const float * tensor_split);
13
-
14
- void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
15
- bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
16
- size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
17
- void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
18
-
19
- // TODO: export these with GGML_API
20
- void * ggml_cuda_host_malloc(size_t size);
21
- void ggml_cuda_host_free(void * ptr);
22
-
23
- void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
24
-
25
- void ggml_cuda_free_data(struct ggml_tensor * tensor);
26
- void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
27
- void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
28
- void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
29
- void ggml_cuda_set_main_device(int main_device);
30
- void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
31
- void ggml_cuda_set_scratch_size(size_t scratch_size);
32
- void ggml_cuda_free_scratch(void);
33
- bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
11
+ GGML_API void ggml_init_cublas(void);
12
+ GGML_API void * ggml_cuda_host_malloc(size_t size);
13
+ GGML_API void ggml_cuda_host_free(void * ptr);
14
+
15
+ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
16
+ GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
17
+ GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
18
+ GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
19
+ GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
20
+ GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
21
+ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
22
+ GGML_API void ggml_cuda_set_main_device(int main_device);
23
+ GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
24
+ GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
25
+ GGML_API void ggml_cuda_free_scratch(void);
26
+ GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
27
+
28
+ GGML_API int ggml_cuda_get_device_count(void);
29
+ GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
34
30
 
35
31
  #ifdef __cplusplus
36
32
  }
@@ -63,10 +63,13 @@ void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
63
63
 
64
64
  // try to find operations that can be run concurrently in the graph
65
65
  // you should run it again if the topology of your graph changes
66
- void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
66
+ void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
67
67
 
68
- // if the graph has been optimized for concurrently dispatch
69
- bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
68
+ // if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
69
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
70
+
71
+ // output the concur_list for ggml_alloc
72
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
70
73
 
71
74
  // same as ggml_graph_compute but uses Metal
72
75
  // creates gf->n_threads command buffers in parallel
@@ -5,7 +5,6 @@
5
5
  #import <Foundation/Foundation.h>
6
6
 
7
7
  #import <Metal/Metal.h>
8
- #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
9
8
 
10
9
  #undef MIN
11
10
  #undef MAX
@@ -79,6 +78,14 @@ struct ggml_metal_context {
79
78
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
80
79
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
81
80
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
81
+ GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
82
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
83
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
84
+ GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
85
+ GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
86
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
87
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
88
+ GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
82
89
  GGML_METAL_DECL_KERNEL(rope);
83
90
  GGML_METAL_DECL_KERNEL(alibi_f32);
84
91
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -110,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
110
117
  ctx->n_buffers = 0;
111
118
  ctx->concur_list_len = 0;
112
119
 
113
- // determine if we can use MPS
114
- if (MPSSupportsMTLDevice(ctx->device)) {
115
- fprintf(stderr, "%s: using MPS\n", __func__);
116
- } else {
117
- fprintf(stderr, "%s: not using MPS\n", __func__);
118
- GGML_ASSERT(false && "MPS not supported");
119
- }
120
120
 
121
121
  #if 0
122
122
  // compile from source string and show compile log
@@ -126,7 +126,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
126
126
  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
127
127
  if (error) {
128
128
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
129
- exit(1);
129
+ return NULL;
130
130
  }
131
131
  }
132
132
  #else
@@ -144,7 +144,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
144
144
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
145
145
  if (error) {
146
146
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
147
- exit(1);
147
+ return NULL;
148
148
  }
149
149
 
150
150
  #ifdef GGML_QKK_64
@@ -156,17 +156,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
156
156
  #endif
157
157
  if (error) {
158
158
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
159
- exit(1);
159
+ return NULL;
160
160
  }
161
161
  }
162
162
  #endif
163
163
 
164
164
  // load kernels
165
165
  {
166
+ NSError * error = nil;
166
167
  #define GGML_METAL_ADD_KERNEL(name) \
167
168
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
168
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
169
- fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
169
+ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
170
+ fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
171
+ if (error) { \
172
+ fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
173
+ return NULL; \
174
+ }
170
175
 
171
176
  GGML_METAL_ADD_KERNEL(add);
172
177
  GGML_METAL_ADD_KERNEL(add_row);
@@ -196,6 +201,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
196
201
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
197
202
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
198
203
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
204
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
205
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
206
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
207
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
208
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
209
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
210
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
211
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
199
212
  GGML_METAL_ADD_KERNEL(rope);
200
213
  GGML_METAL_ADD_KERNEL(alibi_f32);
201
214
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -228,11 +241,12 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
228
241
  ctx->n_cb = n_cb;
229
242
  }
230
243
 
231
- bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
232
- if (ctx->concur_list_len) {
233
- return true;
234
- }
235
- return false;
244
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
245
+ return ctx->concur_list_len;
246
+ }
247
+
248
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
249
+ return ctx->concur_list;
236
250
  }
237
251
 
238
252
  // finds the Metal buffer that contains the tensor data on the GPU device
@@ -375,7 +389,7 @@ void ggml_metal_get_tensor(
375
389
 
376
390
  void ggml_metal_graph_find_concurrency(
377
391
  struct ggml_metal_context * ctx,
378
- struct ggml_cgraph * gf) {
392
+ struct ggml_cgraph * gf, bool check_mem) {
379
393
  int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
380
394
  int nodes_unused[GGML_MAX_CONCUR];
381
395
 
@@ -422,7 +436,7 @@ void ggml_metal_graph_find_concurrency(
422
436
  }
423
437
  }
424
438
  }
425
- if (exe_flag) {
439
+ if (exe_flag && check_mem) {
426
440
  // check if nodes[i]'s data will be overwritten by a node before nodes[i].
427
441
  // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
428
442
  int64_t data_start = (int64_t) gf->nodes[i]->data;
@@ -506,7 +520,7 @@ void ggml_metal_graph_compute(
506
520
 
507
521
  id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
508
522
 
509
- id<MTLComputeCommandEncoder> encoder = nil;
523
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
510
524
 
511
525
  const int node_start = (cb_idx + 0) * n_nodes_per_cb;
512
526
  const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
@@ -515,10 +529,6 @@ void ggml_metal_graph_compute(
515
529
  const int i = has_concur ? ctx->concur_list[ind] : ind;
516
530
 
517
531
  if (i == -1) {
518
- if (encoder == nil) {
519
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
520
- continue;
521
- }
522
532
  [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
523
533
  continue;
524
534
  }
@@ -592,10 +602,6 @@ void ggml_metal_graph_compute(
592
602
  } break;
593
603
  case GGML_OP_ADD:
594
604
  {
595
- if (encoder == nil) {
596
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
597
- }
598
-
599
605
  if (ggml_nelements(src1) == ne10) {
600
606
  // src1 is a row
601
607
  [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -613,10 +619,6 @@ void ggml_metal_graph_compute(
613
619
  } break;
614
620
  case GGML_OP_MUL:
615
621
  {
616
- if (encoder == nil) {
617
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
618
- }
619
-
620
622
  if (ggml_nelements(src1) == ne10) {
621
623
  // src1 is a row
622
624
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -634,10 +636,6 @@ void ggml_metal_graph_compute(
634
636
  } break;
635
637
  case GGML_OP_SCALE:
636
638
  {
637
- if (encoder == nil) {
638
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
639
- }
640
-
641
639
  const float scale = *(const float *) src1->data;
642
640
 
643
641
  [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -653,10 +651,6 @@ void ggml_metal_graph_compute(
653
651
  switch (ggml_get_unary_op(gf->nodes[i])) {
654
652
  case GGML_UNARY_OP_SILU:
655
653
  {
656
- if (encoder == nil) {
657
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
658
- }
659
-
660
654
  [encoder setComputePipelineState:ctx->pipeline_silu];
661
655
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
662
656
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -667,10 +661,6 @@ void ggml_metal_graph_compute(
667
661
  } break;
668
662
  case GGML_UNARY_OP_RELU:
669
663
  {
670
- if (encoder == nil) {
671
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
672
- }
673
-
674
664
  [encoder setComputePipelineState:ctx->pipeline_relu];
675
665
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
676
666
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -681,10 +671,6 @@ void ggml_metal_graph_compute(
681
671
  } break;
682
672
  case GGML_UNARY_OP_GELU:
683
673
  {
684
- if (encoder == nil) {
685
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
686
- }
687
-
688
674
  [encoder setComputePipelineState:ctx->pipeline_gelu];
689
675
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
690
676
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -701,10 +687,6 @@ void ggml_metal_graph_compute(
701
687
  } break;
702
688
  case GGML_OP_SOFT_MAX:
703
689
  {
704
- if (encoder == nil) {
705
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
706
- }
707
-
708
690
  const int nth = 32;
709
691
 
710
692
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -719,10 +701,6 @@ void ggml_metal_graph_compute(
719
701
  } break;
720
702
  case GGML_OP_DIAG_MASK_INF:
721
703
  {
722
- if (encoder == nil) {
723
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
724
- }
725
-
726
704
  const int n_past = ((int32_t *)(dst->op_params))[0];
727
705
 
728
706
  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
@@ -740,53 +718,43 @@ void ggml_metal_graph_compute(
740
718
 
741
719
  GGML_ASSERT(ne00 == ne10);
742
720
  // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
721
+ uint gqa = ne12/ne02;
743
722
  GGML_ASSERT(ne03 == ne13);
744
723
 
724
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
725
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
745
726
  if (ggml_is_contiguous(src0) &&
746
727
  ggml_is_contiguous(src1) &&
747
- (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
748
-
749
- if (encoder != nil) {
750
- [encoder endEncoding];
751
- encoder = nil;
752
- }
753
-
754
- MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
755
- MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
756
-
757
- // for F32 x F32 we use MPS
758
- MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
759
- matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
760
-
761
- MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
762
- matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
763
-
764
- MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
765
- matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
766
-
767
- MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
768
- initWithDevice:ctx->device transposeLeft:false transposeRight:true
769
- resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
770
-
771
- // we need to do ne12 multiplications
772
- // TODO: is there a way to do this in parallel - currently very slow ..
773
- // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
774
- for (int64_t i02 = 0; i02 < ne12; ++i02) {
775
- size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
776
- size_t offs_src1_cur = offs_src1 + i02*nb12;
777
- size_t offs_dst_cur = offs_dst + i02*nb2;
778
-
779
- MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
780
- MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
781
- MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
782
-
783
- [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
784
- }
785
- } else {
786
- if (encoder == nil) {
787
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
728
+ src1t == GGML_TYPE_F32 &&
729
+ [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
730
+ ne00%32 == 0 &&
731
+ ne11 > 1) {
732
+ switch (src0->type) {
733
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
734
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
735
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
736
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
737
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
738
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
739
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
740
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
741
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
742
+ }
743
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
744
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
745
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
746
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
747
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
748
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
749
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
750
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
751
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
752
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
753
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
754
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
755
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
788
756
  }
789
-
757
+ else {
790
758
  int nth0 = 32;
791
759
  int nth1 = 1;
792
760
 
@@ -885,23 +853,24 @@ void ggml_metal_graph_compute(
885
853
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
886
854
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
887
855
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
856
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
888
857
 
889
858
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
890
859
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
891
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
860
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
892
861
  }
893
862
  else if (src0t == GGML_TYPE_Q3_K) {
894
863
  #ifdef GGML_QKK_64
895
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
864
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
896
865
  #else
897
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
866
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
898
867
  #endif
899
868
  }
900
869
  else if (src0t == GGML_TYPE_Q5_K) {
901
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
870
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
902
871
  }
903
872
  else if (src0t == GGML_TYPE_Q6_K) {
904
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
873
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
905
874
  } else {
906
875
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
907
876
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -910,10 +879,6 @@ void ggml_metal_graph_compute(
910
879
  } break;
911
880
  case GGML_OP_GET_ROWS:
912
881
  {
913
- if (encoder == nil) {
914
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
915
- }
916
-
917
882
  switch (src0->type) {
918
883
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
919
884
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
@@ -939,10 +904,6 @@ void ggml_metal_graph_compute(
939
904
  } break;
940
905
  case GGML_OP_RMS_NORM:
941
906
  {
942
- if (encoder == nil) {
943
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
944
- }
945
-
946
907
  float eps;
947
908
  memcpy(&eps, dst->op_params, sizeof(float));
948
909
 
@@ -962,10 +923,6 @@ void ggml_metal_graph_compute(
962
923
  } break;
963
924
  case GGML_OP_NORM:
964
925
  {
965
- if (encoder == nil) {
966
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
967
- }
968
-
969
926
  const float eps = 1e-5f;
970
927
 
971
928
  const int nth = 256;
@@ -984,10 +941,6 @@ void ggml_metal_graph_compute(
984
941
  } break;
985
942
  case GGML_OP_ALIBI:
986
943
  {
987
- if (encoder == nil) {
988
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
989
- }
990
-
991
944
  GGML_ASSERT((src0t == GGML_TYPE_F32));
992
945
 
993
946
  const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
@@ -1027,10 +980,6 @@ void ggml_metal_graph_compute(
1027
980
  } break;
1028
981
  case GGML_OP_ROPE:
1029
982
  {
1030
- if (encoder == nil) {
1031
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
1032
- }
1033
-
1034
983
  const int n_past = ((int32_t *) dst->op_params)[0];
1035
984
  const int n_dims = ((int32_t *) dst->op_params)[1];
1036
985
  const int mode = ((int32_t *) dst->op_params)[2];
@@ -1071,10 +1020,6 @@ void ggml_metal_graph_compute(
1071
1020
  case GGML_OP_CPY:
1072
1021
  case GGML_OP_CONT:
1073
1022
  {
1074
- if (encoder == nil) {
1075
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
1076
- }
1077
-
1078
1023
  const int nth = 32;
1079
1024
 
1080
1025
  switch (src0t) {