llama_cpp 0.5.1 → 0.5.3

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry. The hunks below affect the bundled Metal backend source, ggml-metal.m.
@@ -63,7 +63,10 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(relu);
  GGML_METAL_DECL_KERNEL(gelu);
  GGML_METAL_DECL_KERNEL(soft_max);
+ GGML_METAL_DECL_KERNEL(soft_max_4);
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
+ GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
+ GGML_METAL_DECL_KERNEL(get_rows_f32);
  GGML_METAL_DECL_KERNEL(get_rows_f16);
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
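For context only (not part of this diff): in upstream ggml-metal.m of this era, GGML_METAL_DECL_KERNEL expands to roughly the sketch below, so each added name declares one Metal function handle and one compute pipeline for a new kernel.

    // Rough sketch of the declaration macro (assumed from upstream, shown for context)
    #define GGML_METAL_DECL_KERNEL(name) \
        id<MTLFunction>             function_##name; \
        id<MTLComputePipelineState> pipeline_##name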
@@ -75,8 +78,10 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
  GGML_METAL_DECL_KERNEL(rms_norm);
  GGML_METAL_DECL_KERNEL(norm);
+ GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -85,6 +90,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
@@ -117,14 +123,17 @@ static NSString * const msl_library_source = @"see metal.metal";
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
      metal_printf("%s: allocating\n", __func__);

-     // Show all the Metal device instances in the system
-     NSArray * devices = MTLCopyAllDevices();
      id <MTLDevice> device;
      NSString * s;
+
+ #if TARGET_OS_OSX
+     // Show all the Metal device instances in the system
+     NSArray * devices = MTLCopyAllDevices();
      for (device in devices) {
          s = [device name];
          metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
      }
+ #endif

      // Pick and show default Metal device
      device = MTLCreateSystemDefaultDevice();
@@ -139,14 +148,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
      ctx->n_buffers = 0;
      ctx->concur_list_len = 0;

-     ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

- #if 0
-     // compile from source string and show compile log
+ #ifdef GGML_SWIFT
+     // load the default.metallib file
      {
          NSError * error = nil;

-         ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
+         NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+         NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
+         NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
+         NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
+         NSURL * libURL = [NSURL fileURLWithPath:libPath];
+
+         // Load the metallib file into a Metal library
+         ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+
          if (error) {
              metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
              return NULL;
@@ -161,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

      //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
      NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-     NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+     NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
      metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);

      NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
@@ -207,7 +224,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(relu);
  GGML_METAL_ADD_KERNEL(gelu);
  GGML_METAL_ADD_KERNEL(soft_max);
+ GGML_METAL_ADD_KERNEL(soft_max_4);
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
+ GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
+ GGML_METAL_ADD_KERNEL(get_rows_f32);
  GGML_METAL_ADD_KERNEL(get_rows_f16);
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -219,8 +239,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
  GGML_METAL_ADD_KERNEL(rms_norm);
  GGML_METAL_ADD_KERNEL(norm);
+ GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -229,6 +251,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@@ -247,13 +270,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #undef GGML_METAL_ADD_KERNEL
      }

-     metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
      metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ #if TARGET_OS_OSX
+     metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
      if (ctx->device.maxTransferRate != 0) {
          metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
      } else {
          metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
      }
+ #endif

      return ctx;
  }
@@ -273,7 +298,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(relu);
  GGML_METAL_DEL_KERNEL(gelu);
  GGML_METAL_DEL_KERNEL(soft_max);
+ GGML_METAL_DEL_KERNEL(soft_max_4);
  GGML_METAL_DEL_KERNEL(diag_mask_inf);
+ GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
+ GGML_METAL_DEL_KERNEL(get_rows_f32);
  GGML_METAL_DEL_KERNEL(get_rows_f16);
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -285,8 +313,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
  GGML_METAL_DEL_KERNEL(rms_norm);
  GGML_METAL_DEL_KERNEL(norm);
+ GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
  GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -295,6 +325,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@@ -365,6 +396,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
      for (int i = 0; i < ctx->n_buffers; ++i) {
          const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

+         //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
          if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
              *offs = (size_t) ioffs;

@@ -454,6 +486,7 @@ bool ggml_metal_add_buffer(
          }
      }

+ #if TARGET_OS_OSX
      metal_printf(", (%8.2f / %8.2f)",
              ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
              ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
@@ -463,6 +496,9 @@ bool ggml_metal_add_buffer(
      } else {
          metal_printf("\n");
      }
+ #else
+     metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
+ #endif
  }

  return true;
@@ -698,6 +734,7 @@ void ggml_metal_graph_compute(
      case GGML_OP_ADD:
          {
              GGML_ASSERT(ggml_is_contiguous(src0));
+             GGML_ASSERT(ggml_is_contiguous(src1));

              // utilize float4
              GGML_ASSERT(ne00 % 4 == 0);
@@ -705,6 +742,7 @@ void ggml_metal_graph_compute(

              if (ggml_nelements(src1) == ne10) {
                  // src1 is a row
+                 GGML_ASSERT(ne11 == 1);
                  [encoder setComputePipelineState:ctx->pipeline_add_row];
              } else {
                  [encoder setComputePipelineState:ctx->pipeline_add];
@@ -721,6 +759,7 @@ void ggml_metal_graph_compute(
      case GGML_OP_MUL:
          {
              GGML_ASSERT(ggml_is_contiguous(src0));
+             GGML_ASSERT(ggml_is_contiguous(src1));

              // utilize float4
              GGML_ASSERT(ne00 % 4 == 0);
@@ -728,6 +767,7 @@ void ggml_metal_graph_compute(

              if (ggml_nelements(src1) == ne10) {
                  // src1 is a row
+                 GGML_ASSERT(ne11 == 1);
                  [encoder setComputePipelineState:ctx->pipeline_mul_row];
              } else {
                  [encoder setComputePipelineState:ctx->pipeline_mul];
@@ -743,6 +783,8 @@ void ggml_metal_graph_compute(
          } break;
      case GGML_OP_SCALE:
          {
+             GGML_ASSERT(ggml_is_contiguous(src0));
+
              const float scale = *(const float *) src1->data;

              [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -750,7 +792,7 @@ void ggml_metal_graph_compute(
              [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
              [encoder setBytes:&scale length:sizeof(scale) atIndex:2];

-             const int64_t n = ggml_nelements(dst);
+             const int64_t n = ggml_nelements(dst)/4;

              [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
          } break;
@@ -762,7 +804,7 @@ void ggml_metal_graph_compute(
              [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
              [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-             const int64_t n = ggml_nelements(dst);
+             const int64_t n = ggml_nelements(dst)/4;

              [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
          } break;
@@ -782,7 +824,7 @@ void ggml_metal_graph_compute(
              [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
              [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-             const int64_t n = ggml_nelements(dst);
+             const int64_t n = ggml_nelements(dst)/4;

              [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
          } break;
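Note on the ne-count changes in the preceding hunks (an illustrative sketch, not part of the diff): these element-wise kernels now operate on float4 vectors, so one thread covers four scalar elements and the dispatch size drops to a quarter of the element count; the ne00 % 4 == 0 asserts added earlier in the diff keep the division exact. A minimal C sketch of the arithmetic:

    /* Minimal sketch (illustrative only): thread count for a float4
       element-wise kernel, given the total scalar element count. */
    #include <assert.h>
    #include <stdint.h>

    static int64_t n_float4_threads(int64_t n_elements) {
        assert(n_elements % 4 == 0);  /* mirrors GGML_ASSERT(ne00 % 4 == 0) */
        return n_elements / 4;        /* one thread per 4 consecutive floats */
    }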
@@ -796,13 +838,16 @@ void ggml_metal_graph_compute(
          {
              const int nth = 32;

-             [encoder setComputePipelineState:ctx->pipeline_soft_max];
+             if (ne00%4 == 0) {
+                 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
+             } else {
+                 [encoder setComputePipelineState:ctx->pipeline_soft_max];
+             }
              [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
              [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
              [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
              [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
              [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-             [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

              [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
          } break;
@@ -810,14 +855,23 @@ void ggml_metal_graph_compute(
          {
              const int n_past = ((int32_t *)(dst->op_params))[0];

-             [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+             if (ne00%8 == 0) {
+                 [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
+             } else {
+                 [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+             }
              [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
              [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
              [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
              [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
              [encoder setBytes:&n_past length:sizeof(int) atIndex:4];

-             [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+             if (ne00%8 == 0) {
+                 [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+             }
+             else {
+                 [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+             }
          } break;
      case GGML_OP_MUL_MAT:
          {
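Brief note on the hunk above (an illustrative sketch, not part of the diff): the vectorized diag_mask_inf_8 kernel writes eight values per thread, so when ne00 is divisible by 8 the 3-D grid of single-element threads collapses into a 1-D grid one eighth the size. A minimal C sketch of the grid-size choice:

    /* Minimal sketch (illustrative only): total threads dispatched for the
       diagonal -inf mask, depending on whether the 8-wide kernel applies. */
    #include <stdint.h>

    static int64_t diag_mask_inf_threads(int64_t ne00, int64_t ne01, int64_t ne02) {
        if (ne00 % 8 == 0) {
            return ne00*ne01*ne02/8;  /* one thread per 8 elements (1-D grid) */
        }
        return ne00*ne01*ne02;        /* one thread per element (3-D grid) */
    }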
@@ -830,13 +884,14 @@ void ggml_metal_graph_compute(

              // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
              // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-             if (ggml_is_contiguous(src0) &&
-                 ggml_is_contiguous(src1) &&
+             if (!ggml_is_transposed(src0) &&
+                 !ggml_is_transposed(src1) &&
                  src1t == GGML_TYPE_F32 &&
                  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                  ne00%32 == 0 &&
                  ne11 > 1) {
                  switch (src0->type) {
+                     case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
                      case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
                      case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                      case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
@@ -856,25 +911,38 @@ void ggml_metal_graph_compute(
                  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
                  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
                  [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-                 [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
-                 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
-                 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+                 [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+                 [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+                 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+                 [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
+                 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
+                 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
                  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                  [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
              } else {
                  int nth0 = 32;
                  int nth1 = 1;
+                 int nrows = 1;

                  // use custom matrix x vector kernel
                  switch (src0t) {
+                     case GGML_TYPE_F32:
+                         {
+                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+                             nrows = 4;
+                         } break;
                      case GGML_TYPE_F16:
                          {
                              nth0 = 32;
                              nth1 = 1;
                              if (ne11 * ne12 < 4) {
                                  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                             } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+                                 nrows = ne11;
                              } else {
                                  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                 nrows = 4;
                              }
                          } break;
                      case GGML_TYPE_Q4_0:
@@ -995,7 +1063,7 @@ void ggml_metal_graph_compute(
              else if (src0t == GGML_TYPE_Q6_K) {
                  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
              } else {
-                 int64_t ny = (ne11 + 3)/4;
+                 int64_t ny = (ne11 + nrows - 1)/nrows;
                  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
              }
          }
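Note on the change above (an illustrative sketch, not part of the diff): the ny grid dimension is now sized with a generic ceiling division by nrows, the number of rows each matrix-vector kernel covers per dispatch, instead of the previously hard-coded group of 4. With nrows = 4 it reproduces the old (ne11 + 3)/4, and for the mul_mat_f16_f32_l4 path nrows = ne11, collapsing that dimension to 1. A minimal C sketch:

    /* Minimal sketch (illustrative only): ceiling division used to size the
       ny grid dimension from the number of rows a kernel covers at once. */
    #include <stdint.h>

    static int64_t ceil_div(int64_t n, int64_t d) {
        return (n + d - 1) / d;       /* smallest count of d-sized groups covering n */
    }
    /* ny = ceil_div(ne11, nrows);       e.g. ceil_div(10, 4) == 3 */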
@@ -1003,6 +1071,7 @@ void ggml_metal_graph_compute(
      case GGML_OP_GET_ROWS:
          {
              switch (src0->type) {
+                 case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
                  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
@@ -1018,9 +1087,9 @@ void ggml_metal_graph_compute(
              [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
              [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
              [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-             [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
-             [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
-             [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
+             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+             [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];

              const int64_t n = ggml_nelements(src1);