llama_cpp 0.3.4 → 0.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
  // get data from the device into host memory
  void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);

+ // try to find operations that can be run concurrently in the graph
+ // you should run it again if the topology of your graph changes
+ void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+ // if the graph has been optimized for concurrent dispatch
+ bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
  // same as ggml_graph_compute but uses Metal
  // creates gf->n_threads command buffers in parallel
  void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
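The hunk above adds two entry points to the public Metal header. Below is a minimal, hedged sketch of how a caller might drive them, assuming the usual ggml_metal_init / ggml_metal_graph_compute flow declared in this same header; the graph and buffer setup are only indicated, and the variable names are illustrative, not from the package.

```c
// Hypothetical usage sketch (not part of the diff): build the concurrency
// plan once per graph topology, then reuse it for every compute call.
struct ggml_metal_context * ctx = ggml_metal_init(1);

// ... map host buffers with ggml_metal_add_buffer(...) and build a
//     struct ggml_cgraph * gf elsewhere ...

ggml_metal_graph_find_concurrency(ctx, gf); // re-run if the graph topology changes

if (ggml_metal_if_optimized(ctx)) {
    // the nodes will be encoded with a concurrent dispatch type
}

ggml_metal_graph_compute(ctx, gf);
```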
@@ -36,12 +36,16 @@ struct ggml_metal_context {
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];

+ int concur_list[GGML_MAX_NODES];
+ int concur_list_len;
+
  // custom kernels
  #define GGML_METAL_DECL_KERNEL(name) \
  id<MTLFunction> function_##name; \
  id<MTLComputePipelineState> pipeline_##name

  GGML_METAL_DECL_KERNEL(add);
+ GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(mul);
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(scale);
@@ -97,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->device = MTLCreateSystemDefaultDevice();
  ctx->queue = [ctx->device newCommandQueue];
  ctx->n_buffers = 0;
+ ctx->concur_list_len = 0;

  // determine if we can use MPS
  if (MPSSupportsMTLDevice(ctx->device)) {
@@ -157,6 +162,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);

  GGML_METAL_ADD_KERNEL(add);
+ GGML_METAL_ADD_KERNEL(add_row);
  GGML_METAL_ADD_KERNEL(mul);
  GGML_METAL_ADD_KERNEL(mul_row);
  GGML_METAL_ADD_KERNEL(scale);
@@ -215,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
  ctx->n_cb = n_cb;
  }

+ bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+ if (ctx->concur_list_len) {
+ return true;
+ }
+ return false;
+ }
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -353,11 +366,98 @@ void ggml_metal_get_tensor(
  memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
  }

+ void ggml_metal_graph_find_concurrency(
+ struct ggml_metal_context * ctx,
+ struct ggml_cgraph * gf) {
+ int search_depth = gf->n_nodes; // we only search for concurrency within this range to avoid wasting too much time
+ int nodes_unused[GGML_MAX_NODES];
+
+ for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+ for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+ ctx->concur_list_len = 0;
+
+ int n_left = gf->n_nodes;
+ int n_start = 0; // all nodes before n_start in the nodes_unused array have been sorted and stored back into ctx->concur_list
+ int level_pos = 0; // in ctx->concur_list, the last layer (level) ends at level_pos
+
+ while (n_left > 0) {
+ // number of nodes in a layer (that can be issued concurrently)
+ int concurrency = 0;
+ for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+ if (nodes_unused[i]) {
+ // if the requirements for gf->nodes[i] are satisfied
+ int exe_flag=1;
+ // scan all srcs
+ for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+ struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+ if (src_cur) {
+ // if it is a leaf node, the requirement is satisfied.
+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+ // otherwise this src should be the output of a previously scheduled node.
+ int is_found = 0;
+ // scan 2*search_depth back because we inserted barriers.
+ for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+ if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+ }
+ if (is_found == 0) {exe_flag = 0; break;}
+ }
+ }
+ if (exe_flag) {
+ // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+ // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+ int64_t data_start = (int64_t) gf->nodes[i]->data;
+ int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+ for (int j = n_start; j < i; j++) {
+ if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+ && gf->nodes[j]->op != GGML_OP_VIEW \
+ && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+ && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+ if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+ ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+ continue;
+ } else {
+ exe_flag = 0;
+ }
+ }
+ }
+ }
+ if (exe_flag) {
+ ctx->concur_list[level_pos + concurrency] = i;
+ nodes_unused[i] = 0;
+ concurrency++;
+ ctx->concur_list_len++;
+ }
+ }
+ }
+ n_left -= concurrency;
+ // add a barrier to separate the layers
+ ctx->concur_list[level_pos + concurrency] = -1;
+ ctx->concur_list_len++;
+ // skip over the nodes that have already been scheduled
+ while (!nodes_unused[n_start]) {n_start++;}
+ level_pos += concurrency + 1;
+ }
+
+ if (ctx->concur_list_len > GGML_MAX_NODES) {
+ fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+ }
+ }
+
  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
  metal_printf("%s: evaluating graph\n", __func__);

+ // if ctx->concur_list has been built, dispatch concurrently
+ // otherwise fall back to serial dispatch
+ MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+ const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
+ edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
  // create multiple command buffers and enqueue them
  // then, we encode the graph into the command buffers in parallel

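For reference, the schedule built by ggml_metal_graph_find_concurrency above is a flat list of node indices in which each concurrency level is terminated by a -1 sentinel; the encoding loop later in this diff turns those sentinels into memory barriers. A hedged, illustrative C helper (not part of the package) that walks such a list, using only the field names from the diff:

```c
#include <stdio.h>

// Illustrative only: print the concurrency plan stored in ctx->concur_list.
// Each run of indices between -1 sentinels is one level whose nodes may be
// encoded with a concurrent dispatch type; each sentinel maps to a barrier.
static void print_concur_list(const int * concur_list, int concur_list_len) {
    int level = 0;
    printf("level %d:", level);
    for (int i = 0; i < concur_list_len; i++) {
        if (concur_list[i] == -1) {
            printf("\n");
            if (i + 1 < concur_list_len) {
                printf("level %d:", ++level);
            }
        } else {
            printf(" node %d", concur_list[i]);
        }
    }
}
```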
@@ -376,7 +476,7 @@ void ggml_metal_graph_compute(
  dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

  dispatch_async(queue, ^{
  size_t offs_src0 = 0;
@@ -387,10 +487,21 @@ void ggml_metal_graph_compute(

  id<MTLComputeCommandEncoder> encoder = nil;

- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+ for (int ind = node_start; ind < node_end; ++ind) {
+ const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+ if (i == -1) {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ continue;
+ }
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+ continue;
+ }

- for (int i = node_start; i < node_end; ++i) {
  metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -461,13 +572,19 @@ void ggml_metal_graph_compute(
  case GGML_OP_ADD:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

- [encoder setComputePipelineState:ctx->pipeline_add];
+ if (ggml_nelements(src1) == ne10) {
+ // src1 is a row
+ [encoder setComputePipelineState:ctx->pipeline_add_row];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_add];
+ }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];

  const int64_t n = ggml_nelements(dst);

@@ -476,7 +593,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_MUL:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  if (ggml_nelements(src1) == ne10) {
@@ -497,7 +614,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_SCALE:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  const float scale = *(const float *) src1->data;
@@ -511,52 +628,60 @@ void ggml_metal_graph_compute(

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
- case GGML_OP_SILU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_silu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_RELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_relu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(gf->nodes[i])) {
+ case GGML_UNARY_OP_SILU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_relu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ default:
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
  } break;
- case GGML_OP_GELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_gelu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
  case GGML_OP_SOFT_MAX:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  const int nth = 32;
@@ -574,10 +699,10 @@ void ggml_metal_graph_compute(
  case GGML_OP_DIAG_MASK_INF:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

- const int n_past = ((int32_t *)(src1->data))[0];
+ const int n_past = ((int32_t *)(dst->op_params))[0];

  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -637,7 +762,7 @@ void ggml_metal_graph_compute(
  }
  } else {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  int nth0 = 32;
@@ -764,7 +889,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_GET_ROWS:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  switch (src0->type) {
@@ -793,10 +918,11 @@ void ggml_metal_graph_compute(
  case GGML_OP_RMS_NORM:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

- const float eps = 1e-6f;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));

  const int nth = 512;

@@ -815,7 +941,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_NORM:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  const float eps = 1e-5f;
@@ -837,14 +963,15 @@ void ggml_metal_graph_compute(
  case GGML_OP_ALIBI:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  GGML_ASSERT((src0t == GGML_TYPE_F32));

- const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
- const int n_head = ((int32_t *) src1->data)[1];
- const float max_bias = ((float *) src1->data)[2];
+ const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+ const int n_head = ((int32_t *) dst->op_params)[1];
+ float max_bias;
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

  if (__builtin_popcount(n_head) != 1) {
  GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -879,18 +1006,17 @@ void ggml_metal_graph_compute(
  case GGML_OP_ROPE:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

- const int n_dims = ((int32_t *) src1->data)[1];
- const int mode = ((int32_t *) src1->data)[2];
-
- const int n_past = ((int32_t *)(src1->data))[0];
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];

  float freq_base;
  float freq_scale;
- memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+ memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

  [encoder setComputePipelineState:ctx->pipeline_rope];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
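A pattern repeated across the hunks above (GGML_OP_DIAG_MASK_INF, GGML_OP_RMS_NORM, GGML_OP_ALIBI, GGML_OP_ROPE) is that operator parameters are now read from dst->op_params rather than from src1->data. The following is a hedged sketch of that host-side pattern, not code from the package; it assumes a struct ggml_tensor with an op_params array as declared in ggml.h, and a hypothetical helper name:

```c
#include <stdint.h>
#include <string.h>

// Sketch only: integer parameters are indexed directly, float parameters are
// copied out with memcpy, mirroring the ROPE case in the diff above.
static void read_rope_params(const struct ggml_tensor * dst,
                             int * n_past, int * n_dims, int * mode,
                             float * freq_base, float * freq_scale) {
    const int32_t * params = (const int32_t *) dst->op_params;
    *n_past = params[0];
    *n_dims = params[1];
    *mode   = params[2];
    memcpy(freq_base,  params + 4, sizeof(float));
    memcpy(freq_scale, params + 5, sizeof(float));
}
```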
@@ -919,10 +1045,12 @@ void ggml_metal_graph_compute(

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case GGML_OP_DUP:
  case GGML_OP_CPY:
+ case GGML_OP_CONT:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  const int nth = 32;
@@ -969,8 +1097,10 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  default:
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ASSERT(false);
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
  }
  }

@@ -67,6 +67,17 @@ kernel void kernel_add(
  dst[tpig] = src0[tpig] + src1[tpig];
  }

+ // assumption: src1 is a row
+ // broadcast src1 into src0
+ kernel void kernel_add_row(
+ device const float * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] + src1[tpig % ne00];
+ }
+
  kernel void kernel_mul(
  device const float * src0,
  device const float * src1,
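For reference, the broadcast semantics of the new kernel_add_row can be restated on the CPU side. This is a hedged sketch, not code from the package; n stands for ggml_nelements(dst) and ne00 for the row length, as in the kernel above:

```c
#include <stdint.h>

// CPU reference for kernel_add_row: src1 is a single row of ne00 floats that
// is added element-wise to every row of src0.
static void add_row_ref(const float * src0, const float * src1, float * dst,
                        int64_t n, int64_t ne00) {
    for (int64_t i = 0; i < n; i++) {
        dst[i] = src0[i] + src1[i % ne00];
    }
}
```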
@@ -376,87 +387,90 @@ kernel void kernel_rms_norm(
  }
  }

- // function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
- float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+ // function for calculating the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
+ // we assume that the yl's have been multiplied by the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
  float d = qb_curr->d;
- float4 acc = 0.f;
- device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
- for (int i = 0; i < 16; i+=2) {
- acc[0] += yl[i] * (qs[i / 2] & 0x000F);
- acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
- acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+ float2 acc = 0.f;
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
  }
- return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+ return d * (sumy * -8.f + acc[0] + acc[1]);
  }

- // function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
- float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+ // function for calculating the inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
+ // we assume that the yl's have been multiplied by the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
  float d = qb_curr->d;
  float m = qb_curr->m;
- float4 acc = 0.f;
- device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
- for (int i = 0; i < 16; i+=2) {
- acc[0] += yl[i] * (qs[i / 2] & 0x000F);
- acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
- acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+ float2 acc = 0.f;
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
  }
- return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+ return d * (acc[0] + acc[1]) + sumy * m;
  }

  // putting them in the kernel cause a significant performance penalty
  #define N_DST 4 // each SIMD group works on 4 rows
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
- template<typename block_q_type>
+ // Note: This is a template, but strictly speaking it only applies to
+ // quantizations where the block size is 32. It also does not
+ // guard against the number of rows not being divisible by
+ // N_DST, so this is another explicit assumption of the implementation.
+ template<typename block_q_type, int nr, int nsg, int nw>
  void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
  int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
  uint2 tgpig, uint tiisg, uint sgitg) {
  const int nb = ne00/QK4_0;
  const int r0 = tgpig.x;
  const int r1 = tgpig.y;
- device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+ const int first_row = (r0 * nsg + sgitg) * nr;
+ device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
  device const float * y = (device const float *) src1 + r1*ne10;
- float4 y_curr[8]; // src1 vector cache
- float sumf[N_DST]={0.f}, all_sum;
- thread float * yl=(thread float *)y_curr;
+ float yl[16]; // src1 vector cache
+ float sumf[nr]={0.f};

- // each thread in a SIMD group deals with 1 block.
- for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
- float sumy = 0;
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
- sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
- }
+ const int ix = tiisg/2;
+ const int il = 8*(tiisg%2);

- for (int row = 0; row < N_DST; row++) {
- sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
- }
- }
+ device const float * yb = y + ix * QK4_0 + il;

- // from now loads two rows every time and 16 blocks per row
- int ir = tiisg / (N_SIMDWIDTH / 2);
- int ib = tiisg % (N_SIMDWIDTH / 2);
- for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
- int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
  float sumy = 0;
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
- sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+ for (int i = 0; i < 8; i += 2) {
+ sumy += yb[i] + yb[i+1];
+ yl[i+0] = yb[i+ 0];
+ yl[i+1] = yb[i+ 1]/256.f;
+ sumy += yb[i+16] + yb[i+17];
+ yl[i+8] = yb[i+16]/16.f;
+ yl[i+9] = yb[i+17]/4096.f;
  }

- for (int row = 0; row < N_DST; row+=2) {
- if (nb_start + ib < nb) {
- sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
- }
+ for (int row = 0; row < nr; row++) {
+ sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
  }
+
+ yb += QK4_0 * 16;
  }

- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[r1*ne0 + first_row + row] = tot;
  }
  }
  }
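The restructured block_q_n_dot_y functions lean on the scale-factor trick described in their comments: the masked nibbles stay at their packed bit positions, and the corresponding y values are pre-divided so the missing shifts cancel. A hedged reconstruction of the arithmetic (not text from the package) follows.

```latex
% A 16-bit word qs packs four 4-bit quants q_a, q_b, q_c, q_d at bit offsets 0, 4, 8, 12:
%   qs & 0x000F = q_a,   qs & 0x00F0 = 16 q_b,   qs & 0x0F00 = 256 q_c,   qs & 0xF000 = 4096 q_d,
% so dividing the matching y values by 1, 16, 256, 4096 turns each masked product into q \cdot y.
% q4_0 dequantizes as x_j = d\,(q_j - 8), hence
\sum_j x_j y_j = d \Big( \sum_j q_j y_j - 8 \sum_j y_j \Big)
              = d \, \big( \mathrm{acc}[0] + \mathrm{acc}[1] - 8 \cdot \mathrm{sumy} \big),
% while q4_1 dequantizes as x_j = d\,q_j + m, hence
\sum_j x_j y_j = d \, \big( \mathrm{acc}[0] + \mathrm{acc}[1] \big) + m \cdot \mathrm{sumy}.
```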
@@ -472,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32(
  uint2 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
  }

  kernel void kernel_mul_mat_q4_1_f32(
@@ -486,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32(
  uint2 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
  }

  kernel void kernel_mul_mat_f16_f32(