llama_cpp 0.5.1 → 0.5.3

@@ -63,7 +63,10 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(relu);
  GGML_METAL_DECL_KERNEL(gelu);
  GGML_METAL_DECL_KERNEL(soft_max);
+ GGML_METAL_DECL_KERNEL(soft_max_4);
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
+ GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
+ GGML_METAL_DECL_KERNEL(get_rows_f32);
  GGML_METAL_DECL_KERNEL(get_rows_f16);
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
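
For orientation, GGML_METAL_DECL_KERNEL declares per-kernel handles inside ggml_metal_context; every name declared here later gets a pipeline built via GGML_METAL_ADD_KERNEL in ggml_metal_init and released via GGML_METAL_DEL_KERNEL in ggml_metal_free (both visible further down in this diff). A minimal Objective-C sketch of what such a declaration macro plausibly expands to — the real macro body is outside this diff, so the member names are an assumption:

    #import <Metal/Metal.h>

    // Hypothetical expansion: one MTLFunction handle and one compute pipeline per kernel name.
    #define GGML_METAL_DECL_KERNEL(name) \
        id<MTLFunction>             function_##name; \
        id<MTLComputePipelineState> pipeline_##name

    // e.g. GGML_METAL_DECL_KERNEL(soft_max_4); would add function_soft_max_4 and
    // pipeline_soft_max_4, the vectorized soft max pipeline selected further down.
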
@@ -75,8 +78,10 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
  GGML_METAL_DECL_KERNEL(rms_norm);
  GGML_METAL_DECL_KERNEL(norm);
+ GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -85,6 +90,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
@@ -117,14 +123,17 @@ static NSString * const msl_library_source = @"see metal.metal";
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
  metal_printf("%s: allocating\n", __func__);

- // Show all the Metal device instances in the system
- NSArray * devices = MTLCopyAllDevices();
  id <MTLDevice> device;
  NSString * s;
+
+ #if TARGET_OS_OSX
+ // Show all the Metal device instances in the system
+ NSArray * devices = MTLCopyAllDevices();
  for (device in devices) {
  s = [device name];
  metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
  }
+ #endif

  // Pick and show default Metal device
  device = MTLCreateSystemDefaultDevice();
@@ -139,14 +148,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->n_buffers = 0;
  ctx->concur_list_len = 0;

- ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+ ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

- #if 0
- // compile from source string and show compile log
+ #ifdef GGML_SWIFT
+ // load the default.metallib file
  {
  NSError * error = nil;

- ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
+ NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+ NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
+ NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
+ NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
+ NSURL * libURL = [NSURL fileURLWithPath:libPath];
+
+ // Load the metallib file into a Metal library
+ ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+
  if (error) {
  metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
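
With GGML_SWIFT defined, the backend now loads a precompiled default.metallib shipped in the gem's resource bundle instead of compiling ggml-metal.metal at run time. A minimal Objective-C sketch of that flow, carried one step further to build a compute pipeline; the mainBundle lookup and the "kernel_soft_max_4" function name are assumptions for illustration, not taken verbatim from the patch:

    #import <Foundation/Foundation.h>
    #import <Metal/Metal.h>

    // Illustrative only: locate the nested resource bundle, load the prebuilt
    // Metal library from it, and create one compute pipeline.
    static id<MTLComputePipelineState> make_pipeline(id<MTLDevice> device, NSError ** error) {
        NSString * bundlePath = [[NSBundle mainBundle] pathForResource:@"llama_llama" ofType:@"bundle"];
        NSBundle * resBundle  = [NSBundle bundleWithPath:bundlePath];
        NSString * libPath    = [resBundle pathForResource:@"default" ofType:@"metallib"];
        id<MTLLibrary> lib    = [device newLibraryWithURL:[NSURL fileURLWithPath:libPath] error:error];
        if (lib == nil) {
            return nil;
        }
        id<MTLFunction> fn = [lib newFunctionWithName:@"kernel_soft_max_4"]; // assumed kernel name
        return [device newComputePipelineStateWithFunction:fn error:error];
    }
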
@@ -161,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+ NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
  metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);

  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
@@ -207,7 +224,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(relu);
  GGML_METAL_ADD_KERNEL(gelu);
  GGML_METAL_ADD_KERNEL(soft_max);
+ GGML_METAL_ADD_KERNEL(soft_max_4);
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
+ GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
+ GGML_METAL_ADD_KERNEL(get_rows_f32);
  GGML_METAL_ADD_KERNEL(get_rows_f16);
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -219,8 +239,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
  GGML_METAL_ADD_KERNEL(rms_norm);
  GGML_METAL_ADD_KERNEL(norm);
+ GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -229,6 +251,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@@ -247,13 +270,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #undef GGML_METAL_ADD_KERNEL
  }

- metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
  metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ #if TARGET_OS_OSX
+ metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
  if (ctx->device.maxTransferRate != 0) {
  metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
  } else {
  metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
  }
+ #endif

  return ctx;
  }
@@ -273,7 +298,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(relu);
  GGML_METAL_DEL_KERNEL(gelu);
  GGML_METAL_DEL_KERNEL(soft_max);
+ GGML_METAL_DEL_KERNEL(soft_max_4);
  GGML_METAL_DEL_KERNEL(diag_mask_inf);
+ GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
+ GGML_METAL_DEL_KERNEL(get_rows_f32);
  GGML_METAL_DEL_KERNEL(get_rows_f16);
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -285,8 +313,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
  GGML_METAL_DEL_KERNEL(rms_norm);
  GGML_METAL_DEL_KERNEL(norm);
+ GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
  GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -295,6 +325,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@@ -365,6 +396,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
  for (int i = 0; i < ctx->n_buffers; ++i) {
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

+ //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
  *offs = (size_t) ioffs;
@@ -454,6 +486,7 @@ bool ggml_metal_add_buffer(
  }
  }

+ #if TARGET_OS_OSX
  metal_printf(", (%8.2f / %8.2f)",
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
@@ -463,6 +496,9 @@ bool ggml_metal_add_buffer(
  } else {
  metal_printf("\n");
  }
+ #else
+ metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
+ #endif
  }

  return true;
@@ -698,6 +734,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_ADD:
  {
  GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));

  // utilize float4
  GGML_ASSERT(ne00 % 4 == 0);
@@ -705,6 +742,7 @@

  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
+ GGML_ASSERT(ne11 == 1);
  [encoder setComputePipelineState:ctx->pipeline_add_row];
  } else {
  [encoder setComputePipelineState:ctx->pipeline_add];
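
The new ne11 == 1 assertion pins down what the "src1 is a row" branch means: src1 holds exactly ne10 elements, a single row that gets added to every row of src0. A plain C reference of that broadcast (illustrative shapes; plain C is also valid inside this Objective-C file):

    #include <stdio.h>

    int main(void) {
        // assumed shapes: src0 is ne01 x ne00, src1 is a single row of ne00 values
        enum { ne00 = 4, ne01 = 3 };
        float src0[ne01][ne00] = {{0}};
        float src1[ne00]       = {1, 2, 3, 4};
        float dst [ne01][ne00];

        for (int r = 0; r < ne01; ++r)
            for (int c = 0; c < ne00; ++c)
                dst[r][c] = src0[r][c] + src1[c];    // the same row added to every row

        printf("dst[1][2] = %g\n", dst[1][2]);        // prints 3
        return 0;
    }
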
@@ -721,6 +759,7 @@
  case GGML_OP_MUL:
  {
  GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));

  // utilize float4
  GGML_ASSERT(ne00 % 4 == 0);
@@ -728,6 +767,7 @@

  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
+ GGML_ASSERT(ne11 == 1);
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
  } else {
  [encoder setComputePipelineState:ctx->pipeline_mul];
@@ -743,6 +783,8 @@
  } break;
  case GGML_OP_SCALE:
  {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
  const float scale = *(const float *) src1->data;

  [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -750,7 +792,7 @@
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];

- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
@@ -762,7 +804,7 @@
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
@@ -782,7 +824,7 @@
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
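
The three hunks above change the element-wise dispatch from ggml_nelements(dst) to ggml_nelements(dst)/4: the corresponding kernels now read and write float4, so one thread covers four scalar elements. A small C sketch of that arithmetic (the element count is assumed to be a multiple of 4 here):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t nelements = 4096;        // assumed tensor size
        const int64_t n_scalar  = nelements;   // old grid: one element per thread
        const int64_t n_vec4    = nelements/4; // new grid: one float4 per thread
        printf("scalar grid: %lld, float4 grid: %lld\n",
               (long long) n_scalar, (long long) n_vec4);
        return 0;
    }
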
@@ -796,13 +838,16 @@
  {
  const int nth = 32;

- [encoder setComputePipelineState:ctx->pipeline_soft_max];
+ if (ne00%4 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_soft_max];
+ }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
@@ -810,14 +855,23 @@
  {
  const int n_past = ((int32_t *)(dst->op_params))[0];

- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+ if (ne00%8 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+ }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
  [encoder setBytes:&n_past length:sizeof(int) atIndex:4];

- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ if (ne00%8 == 0) {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
+ else {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
  } break;
  case GGML_OP_MUL_MAT:
  {
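
In the diag mask hunk above, when ne00 is divisible by 8 the new diag_mask_inf_8 pipeline takes over and the dispatch collapses from one thread per element over an (ne00, ne01, ne02) grid to a flat grid of ne00*ne01*ne02/8 threadgroups, eight elements each. A small C sketch of the grid arithmetic (the shape numbers are only an example):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne00 = 512, ne01 = 32, ne02 = 8;    // assumed tensor shape
        const int64_t total  = ne00*ne01*ne02;            // elements to mask
        const int64_t groups = (ne00 % 8 == 0) ? total/8  // vectorized kernel: 8 per group
                                               : total;   // scalar fallback: 1 per group
        printf("%lld elements -> %lld threadgroups\n",
               (long long) total, (long long) groups);
        return 0;
    }
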
@@ -830,13 +884,14 @@

  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if (ggml_is_contiguous(src0) &&
- ggml_is_contiguous(src1) &&
+ if (!ggml_is_transposed(src0) &&
+ !ggml_is_transposed(src1) &&
  src1t == GGML_TYPE_F32 &&
  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
  ne00%32 == 0 &&
  ne11 > 1) {
  switch (src0->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
@@ -856,25 +911,38 @@
  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
  [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  int nth0 = 32;
  int nth1 = 1;
+ int nrows = 1;

  // use custom matrix x vector kernel
  switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+ nrows = 4;
+ } break;
  case GGML_TYPE_F16:
  {
  nth0 = 32;
  nth1 = 1;
  if (ne11 * ne12 < 4) {
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+ nrows = ne11;
  } else {
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ nrows = 4;
  }
  } break;
  case GGML_TYPE_Q4_0:
@@ -995,7 +1063,7 @@
  else if (src0t == GGML_TYPE_Q6_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
- int64_t ny = (ne11 + 3)/4;
+ int64_t ny = (ne11 + nrows - 1)/nrows;
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  }
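
The hard-coded (ne11 + 3)/4 becomes (ne11 + nrows - 1)/nrows, a ceiling division by however many src1 rows the selected matrix-vector kernel consumes per threadgroup: nrows stays 1 by default, is 4 for the f32 and generic f16 kernels, and is ne11 for the new f16 l4 kernel, as set in the hunk above. A small C sketch of the ceiling division (row counts here are only examples):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne11 = 7;                            // assumed number of src1 rows
        for (int64_t nrows = 1; nrows <= 4; ++nrows) {
            const int64_t ny = (ne11 + nrows - 1)/nrows;   // ceil(ne11 / nrows)
            printf("nrows=%lld -> ny=%lld\n", (long long) nrows, (long long) ny);
        }
        return 0;
    }
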
@@ -1003,6 +1071,7 @@
  case GGML_OP_GET_ROWS:
  {
  switch (src0->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
@@ -1018,9 +1087,9 @@
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];

  const int64_t n = ggml_nelements(src1);