llama_cpp 0.3.4 → 0.3.5

@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
  // get data from the device into host memory
  void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+ // try to find operations that can be run concurrently in the graph
+ // you should run it again if the topology of your graph changes
+ void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+ // whether the graph has been optimized for concurrent dispatch
+ bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
  // same as ggml_graph_compute but uses Metal
  // creates gf->n_threads command buffers in parallel
  void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
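For orientation, a minimal host-side sketch of how the new entry points combine with ggml_metal_graph_compute. This is my own reading of the diff, not code from the release; ctx and gf stand for an already initialized ggml_metal_context and a built ggml_cgraph:

    // hedged usage sketch of the new concurrency API
    ggml_metal_graph_find_concurrency(ctx, gf);   // re-run whenever the graph topology changes
    if (ggml_metal_if_optimized(ctx)) {
        // a concurrency plan was found; graph_compute will dispatch concurrently
    }
    ggml_metal_graph_compute(ctx, gf);            // otherwise it falls back to serial dispatch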
@@ -36,12 +36,16 @@ struct ggml_metal_context {
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+ int concur_list[GGML_MAX_NODES];
+ int concur_list_len;
+
  // custom kernels
  #define GGML_METAL_DECL_KERNEL(name) \
  id<MTLFunction> function_##name; \
  id<MTLComputePipelineState> pipeline_##name
 
  GGML_METAL_DECL_KERNEL(add);
+ GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(mul);
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(scale);
@@ -97,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->device = MTLCreateSystemDefaultDevice();
  ctx->queue = [ctx->device newCommandQueue];
  ctx->n_buffers = 0;
+ ctx->concur_list_len = 0;
 
  // determine if we can use MPS
  if (MPSSupportsMTLDevice(ctx->device)) {
@@ -157,6 +162,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
  GGML_METAL_ADD_KERNEL(add);
+ GGML_METAL_ADD_KERNEL(add_row);
  GGML_METAL_ADD_KERNEL(mul);
  GGML_METAL_ADD_KERNEL(mul_row);
  GGML_METAL_ADD_KERNEL(scale);
@@ -215,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
  ctx->n_cb = n_cb;
  }
 
+ bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+ if (ctx->concur_list_len) {
+ return true;
+ }
+ return false;
+ }
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -353,11 +366,98 @@ void ggml_metal_get_tensor(
  memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
  }
 
+ void ggml_metal_graph_find_concurrency(
+ struct ggml_metal_context * ctx,
+ struct ggml_cgraph * gf) {
+ int search_depth = gf->n_nodes; // we only look for concurrency within this range to avoid wasting too much time
+ int nodes_unused[GGML_MAX_NODES];
+
+ for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+ for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+ ctx->concur_list_len = 0;
+
+ int n_left = gf->n_nodes;
+ int n_start = 0; // all nodes before n_start in the nodes_unused array have been sorted and stored back into ctx->concur_list
+ int level_pos = 0; // in ctx->concur_list, the last layer (level) ends at level_pos
+
+ while (n_left > 0) {
+ // number of nodes at a layer (that can be issued concurrently)
+ int concurrency = 0;
+ for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+ if (nodes_unused[i]) {
+ // check whether the requirements for gf->nodes[i] are satisfied
+ int exe_flag=1;
+ // scan all srcs
+ for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+ struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+ if (src_cur) {
+ // if it is a leaf node, the requirement is satisfied
+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+ // otherwise this src must be the output of a previously issued node
+ int is_found = 0;
+ // scan 2*search_depth back because barriers have been inserted
+ for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+ if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+ }
+ if (is_found == 0) {exe_flag = 0; break;}
+ }
+ }
+ if (exe_flag) {
+ // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+ // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+ int64_t data_start = (int64_t) gf->nodes[i]->data;
+ int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+ for (int j = n_start; j < i; j++) {
+ if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+ && gf->nodes[j]->op != GGML_OP_VIEW \
+ && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+ && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+ if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+ ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+ continue;
+ } else {
+ exe_flag = 0;
+ }
+ }
+ }
+ }
+ if (exe_flag) {
+ ctx->concur_list[level_pos + concurrency] = i;
+ nodes_unused[i] = 0;
+ concurrency++;
+ ctx->concur_list_len++;
+ }
+ }
+ }
+ n_left -= concurrency;
+ // add a barrier to separate the layers
+ ctx->concur_list[level_pos + concurrency] = -1;
+ ctx->concur_list_len++;
+ // skip over the nodes that have already been sorted
+ while (!nodes_unused[n_start]) {n_start++;}
+ level_pos += concurrency + 1;
+ }
+
+ if (ctx->concur_list_len > GGML_MAX_NODES) {
+ fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+ }
+ }
+
  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
  metal_printf("%s: evaluating graph\n", __func__);
 
+ // if there is a ctx->concur_list, dispatch concurrently
+ // otherwise fall back to serial dispatch
+ MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+ const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
+ edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
  // create multiple command buffers and enqueue them
  // then, we encode the graph into the command buffers in parallel
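To make the planner's output concrete: ctx->concur_list stores node indices grouped into levels, with -1 acting as a separator (and terminator) between levels. A hypothetical illustration follows; the example layout is assumed, not taken from a real graph, and the real dispatcher is the loop in ggml_metal_graph_compute below:

    #include <stdio.h>

    // assumed example: nodes 0,1,2 form one level, node 3 the next, nodes 4,5 the last
    static const int concur_list[]   = { 0, 1, 2, -1, 3, -1, 4, 5, -1 };
    static const int concur_list_len = 9;

    int main(void) {
        for (int ind = 0; ind < concur_list_len; ++ind) {
            const int i = concur_list[ind];
            if (i == -1) {
                // level boundary: the real encoder issues a buffer memory barrier here
                printf("--- barrier ---\n");
                continue;
            }
            // nodes between two barriers have no data dependencies on each other
            // and may be dispatched concurrently on the GPU
            printf("encode node %d\n", i);
        }
        return 0;
    }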
 
@@ -376,7 +476,7 @@ void ggml_metal_graph_compute(
  dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
  dispatch_async(queue, ^{
  size_t offs_src0 = 0;
@@ -387,10 +487,21 @@ void ggml_metal_graph_compute(
 
  id<MTLComputeCommandEncoder> encoder = nil;
 
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+ for (int ind = node_start; ind < node_end; ++ind) {
+ const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+ if (i == -1) {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ continue;
+ }
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+ continue;
+ }
 
- for (int i = node_start; i < node_end; ++i) {
  metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -461,13 +572,19 @@ void ggml_metal_graph_compute(
  case GGML_OP_ADD:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
- [encoder setComputePipelineState:ctx->pipeline_add];
+ if (ggml_nelements(src1) == ne10) {
+ // src1 is a row
+ [encoder setComputePipelineState:ctx->pipeline_add_row];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_add];
+ }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
  const int64_t n = ggml_nelements(dst);
 
@@ -476,7 +593,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_MUL:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  if (ggml_nelements(src1) == ne10) {
@@ -497,7 +614,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_SCALE:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  const float scale = *(const float *) src1->data;
@@ -511,52 +628,60 @@ void ggml_metal_graph_compute(
 
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
- case GGML_OP_SILU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_silu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_RELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_relu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(gf->nodes[i])) {
+ case GGML_UNARY_OP_SILU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_relu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ default:
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
  } break;
- case GGML_OP_GELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_gelu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
  case GGML_OP_SOFT_MAX:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  const int nth = 32;
@@ -574,10 +699,10 @@ void ggml_metal_graph_compute(
  case GGML_OP_DIAG_MASK_INF:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
- const int n_past = ((int32_t *)(src1->data))[0];
+ const int n_past = ((int32_t *)(dst->op_params))[0];
 
  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -637,7 +762,7 @@ void ggml_metal_graph_compute(
  }
  } else {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  int nth0 = 32;
@@ -764,7 +889,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_GET_ROWS:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  switch (src0->type) {
@@ -793,10 +918,11 @@ void ggml_metal_graph_compute(
  case GGML_OP_RMS_NORM:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
- const float eps = 1e-6f;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
 
  const int nth = 512;
 
@@ -815,7 +941,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_NORM:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  const float eps = 1e-5f;
@@ -837,14 +963,15 @@ void ggml_metal_graph_compute(
  case GGML_OP_ALIBI:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  GGML_ASSERT((src0t == GGML_TYPE_F32));
 
- const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
- const int n_head = ((int32_t *) src1->data)[1];
- const float max_bias = ((float *) src1->data)[2];
+ const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+ const int n_head = ((int32_t *) dst->op_params)[1];
+ float max_bias;
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
  if (__builtin_popcount(n_head) != 1) {
  GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -879,18 +1006,17 @@ void ggml_metal_graph_compute(
  case GGML_OP_ROPE:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
- const int n_dims = ((int32_t *) src1->data)[1];
- const int mode = ((int32_t *) src1->data)[2];
-
- const int n_past = ((int32_t *)(src1->data))[0];
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
 
  float freq_base;
  float freq_scale;
- memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+ memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
  [encoder setComputePipelineState:ctx->pipeline_rope];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
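Several operators above (DIAG_MASK_INF, RMS_NORM, ALIBI, ROPE) now read their parameters from dst->op_params instead of src1->data. A hedged sketch of the unpacking convention the ROPE hunk shows: integer parameters are indexed directly and float parameters are bit-copied with memcpy. The helper name and signature are hypothetical, not part of the release:

    #include <stdint.h>
    #include <string.h>

    // op_params is treated as an int32_t array; float entries are bit-copied, not cast
    static void read_rope_params(const int32_t * op_params,
                                 int * n_past, int * n_dims, int * mode,
                                 float * freq_base, float * freq_scale) {
        *n_past = op_params[0];
        *n_dims = op_params[1];
        *mode   = op_params[2];
        memcpy(freq_base,  op_params + 4, sizeof(float));
        memcpy(freq_scale, op_params + 5, sizeof(float));
    }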
@@ -919,10 +1045,12 @@ void ggml_metal_graph_compute(
 
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case GGML_OP_DUP:
  case GGML_OP_CPY:
+ case GGML_OP_CONT:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  const int nth = 32;
@@ -969,8 +1097,10 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  default:
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ASSERT(false);
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
  }
  }
 
@@ -67,6 +67,17 @@ kernel void kernel_add(
  dst[tpig] = src0[tpig] + src1[tpig];
  }
 
+ // assumption: src1 is a row
+ // broadcast src1 into src0
+ kernel void kernel_add_row(
+ device const float * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] + src1[tpig % ne00];
+ }
+
  kernel void kernel_mul(
  device const float * src0,
  device const float * src1,
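For reference, a scalar sketch (my own, not part of the diff; add_row_ref and its parameters are hypothetical) of the broadcast that kernel_add_row performs, where src1 is a single row of ne00 floats added to every row of src0 and n is the total number of elements of dst:

    #include <stdint.h>

    // CPU equivalent of kernel_add_row: dst[i] = src0[i] + src1[i % ne00]
    static void add_row_ref(const float * src0, const float * src1, float * dst,
                            int64_t n, int64_t ne00) {
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = src0[i] + src1[i % ne00];
        }
    }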
@@ -376,87 +387,90 @@ kernel void kernel_rms_norm(
  }
  }
 
- // function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
- float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+ // function for calculating the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
  float d = qb_curr->d;
- float4 acc = 0.f;
- device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
- for (int i = 0; i < 16; i+=2) {
- acc[0] += yl[i] * (qs[i / 2] & 0x000F);
- acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
- acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+ float2 acc = 0.f;
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
  }
- return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+ return d * (sumy * -8.f + acc[0] + acc[1]);
  }
 
- // function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
- float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+ // function for calculating the inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
  float d = qb_curr->d;
  float m = qb_curr->m;
- float4 acc = 0.f;
- device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
- for (int i = 0; i < 16; i+=2) {
- acc[0] += yl[i] * (qs[i / 2] & 0x000F);
- acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
- acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+ float2 acc = 0.f;
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
  }
- return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+ return d * (acc[0] + acc[1]) + sumy * m;
  }
 
  // putting them in the kernel cause a significant performance penalty
  #define N_DST 4 // each SIMD group works on 4 rows
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
- template<typename block_q_type>
+ // Note: This is a template, but strictly speaking it only applies to
+ // quantizations where the block size is 32. It also does not
+ // guard against the number of rows not being divisible by
+ // N_DST, so this is another explicit assumption of the implementation.
+ template<typename block_q_type, int nr, int nsg, int nw>
  void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
  int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
  uint2 tgpig, uint tiisg, uint sgitg) {
  const int nb = ne00/QK4_0;
  const int r0 = tgpig.x;
  const int r1 = tgpig.y;
- device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+ const int first_row = (r0 * nsg + sgitg) * nr;
+ device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
  device const float * y = (device const float *) src1 + r1*ne10;
- float4 y_curr[8]; // src1 vector cache
- float sumf[N_DST]={0.f}, all_sum;
- thread float * yl=(thread float *)y_curr;
+ float yl[16]; // src1 vector cache
+ float sumf[nr]={0.f};
 
- // each thread in a SIMD group deals with 1 block.
- for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
- float sumy = 0;
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
- sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
- }
+ const int ix = tiisg/2;
+ const int il = 8*(tiisg%2);
 
- for (int row = 0; row < N_DST; row++) {
- sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
- }
- }
+ device const float * yb = y + ix * QK4_0 + il;
 
- // from now loads two rows every time and 16 blocks per row
- int ir = tiisg / (N_SIMDWIDTH / 2);
- int ib = tiisg % (N_SIMDWIDTH / 2);
- for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
- int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
  float sumy = 0;
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
- sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+ for (int i = 0; i < 8; i += 2) {
+ sumy += yb[i] + yb[i+1];
+ yl[i+0] = yb[i+ 0];
+ yl[i+1] = yb[i+ 1]/256.f;
+ sumy += yb[i+16] + yb[i+17];
+ yl[i+8] = yb[i+16]/16.f;
+ yl[i+9] = yb[i+17]/4096.f;
  }
 
- for (int row = 0; row < N_DST; row+=2) {
- if (nb_start + ib < nb) {
- sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
- }
+ for (int row = 0; row < nr; row++) {
+ sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
  }
+
+ yb += QK4_0 * 16;
  }
 
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[r1*ne0 + first_row + row] = tot;
  }
  }
  }
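For readers following the reworked dot product, here is a plain-C scalar sketch (my own, not from the diff; ref_q4_0_half_dot and its arguments are hypothetical) of what one call to the q4_0 overload of block_q_n_dot_y computes. The Metal version reaches the same result without any bit shifts in the inner loop because the yl values were pre-divided by 1, 16, 256 and 4096 to cancel the nibble positions inside each uint16:

    #include <stdint.h>

    // qs8: the 8 quant bytes of one half of a q4_0 block (byte offset 0 or 8);
    // y:   the block's float partners, advanced by the same offset (il);
    // d:   the block scale. q4_0 stores each weight as (q - 8) * d.
    static float ref_q4_0_half_dot(const uint8_t * qs8, float d, const float * y) {
        float acc = 0.f, sumy = 0.f;
        for (int i = 0; i < 8; ++i) {
            const int q_lo = qs8[i] & 0x0F;   // low nibble pairs with y[i]
            const int q_hi = qs8[i] >> 4;     // high nibble pairs with y[i + 16]
            acc  += q_lo * y[i] + q_hi * y[i + 16];
            sumy += y[i] + y[i + 16];
        }
        return d * (acc - 8.f * sumy);        // same quantity as d * (sumy * -8 + acc[0] + acc[1]) above
    }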
@@ -472,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32(
  uint2 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
  }
 
  kernel void kernel_mul_mat_q4_1_f32(
@@ -486,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32(
  uint2 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
  }
 
  kernel void kernel_mul_mat_f16_f32(