llama_cpp 0.3.3 → 0.3.5
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/ext/llama_cpp/extconf.rb +1 -0
- data/ext/llama_cpp/llama_cpp.cpp +439 -9
- data/ext/llama_cpp/src/ggml-cuda.cu +759 -136
- data/ext/llama_cpp/src/ggml-metal.h +7 -0
- data/ext/llama_cpp/src/ggml-metal.m +250 -111
- data/ext/llama_cpp/src/ggml-metal.metal +614 -483
- data/ext/llama_cpp/src/ggml.c +793 -1032
- data/ext/llama_cpp/src/ggml.h +95 -18
- data/ext/llama_cpp/src/k_quants.c +327 -3
- data/ext/llama_cpp/src/k_quants.h +8 -0
- data/ext/llama_cpp/src/llama.cpp +626 -166
- data/ext/llama_cpp/src/llama.h +94 -10
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -0
- data/sig/llama_cpp.rbs +36 -1
- metadata +2 -2
data/ext/llama_cpp/src/ggml-metal.h

@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 // get data from the device into host memory
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// if the graph has been optimized for concurrently dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
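Note: the two declarations added above are the new public concurrency API. A minimal call-order sketch (an illustration, not code from the package; context creation and graph construction are elided):

    #include "ggml.h"
    #include "ggml-metal.h"

    // hypothetical helper: analyze the graph once, then compute
    static void metal_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf) {
        // build the concurrency list; per the header comment, re-run this
        // whenever the topology of the graph changes
        ggml_metal_graph_find_concurrency(ctx, gf);

        if (ggml_metal_if_optimized(ctx)) {
            // a non-empty concurrency list was found; ggml_metal_graph_compute
            // will encode with a concurrent dispatch type
        }

        ggml_metal_graph_compute(ctx, gf);
    }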
data/ext/llama_cpp/src/ggml-metal.m

@@ -36,12 +36,16 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+    int concur_list[GGML_MAX_NODES];
+    int concur_list_len;
+
     // custom kernels
 #define GGML_METAL_DECL_KERNEL(name) \
     id<MTLFunction> function_##name; \
     id<MTLComputePipelineState> pipeline_##name
 
     GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
@@ -97,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
+    ctx->concur_list_len = 0;
 
     // determine if we can use MPS
     if (MPSSupportsMTLDevice(ctx->device)) {
@@ -157,6 +162,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
         GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
@@ -215,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
 
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -353,11 +366,98 @@ void ggml_metal_get_tensor(
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
 }
 
+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+    for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+    ctx->concur_list_len = 0;
+
+    int n_left = gf->n_nodes;
+    int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // if the requirements for gf->nodes[i] are satisfied
+                int exe_flag=1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // if is leaf nodes it's satisfied.
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+                        // otherwise this src should be the output from previous nodes.
+                        int is_found = 0;
+                        // scan 2*search_depth back because we inserted barrier.
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+                        }
+                        if (is_found == 0) {exe_flag = 0; break;}
+                    }
+                }
+                if (exe_flag) {
+                    // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+                    // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+                                            && gf->nodes[j]->op != GGML_OP_VIEW \
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // adding a barrier different layer
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // jump all sorted nodes at nodes_bak
+        while (!nodes_unused[n_start]) {n_start++;}
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    // if there is ctx->concur_list, dispatch concurrently
+    // else fallback to serial dispatch
+    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+    const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
+    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
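Note: the pass above layers the graph greedily. A node enters the current level only when every source is a leaf or was produced in an earlier level of concur_list, and no earlier unissued node overlaps its output buffer. A sketch of the resulting encoding (illustrative graph, not data from the diff):

    // nodes 0-2 mutually independent, node 3 consumes their outputs:
    static const int concur_list_example[] = { 0, 1, 2, -1, 3, -1 };
    // entries between the -1 sentinels form one level and may be issued
    // concurrently; each -1 becomes a memory barrier at encode time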
@@ -376,7 +476,7 @@ void ggml_metal_graph_compute(
     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
         dispatch_async(queue, ^{
             size_t offs_src0 = 0;
@@ -387,10 +487,21 @@ void ggml_metal_graph_compute(
 
             id<MTLComputeCommandEncoder> encoder = nil;
 
-            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start =                                  (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int ind = node_start; ind < node_end; ++ind) {
+                const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+                if (i == -1) {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                        continue;
+                    }
+                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    continue;
+                }
 
-            for (int i = node_start; i < node_end; ++i) {
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
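Note: the slice arithmetic above splits the (possibly concurrency-reordered) node list across command buffers by ceiling division, with the last buffer clamped to n_nodes. Worked through with illustrative numbers:

    // n_nodes = 7, n_cb = 2 (illustrative values)
    // n_nodes_per_cb = (7 + 2 - 1) / 2 = 4
    // cb 0: node_start = 0, node_end = 4  -> encodes entries 0..3
    // cb 1: node_start = 4, node_end = 7  -> encodes entries 4..6 (clamped)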
@@ -461,13 +572,19 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ADD:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
-                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            if (ggml_nelements(src1) == ne10) {
+                                // src1 is a row
+                                [encoder setComputePipelineState:ctx->pipeline_add_row];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_add];
+                            }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
                             const int64_t n = ggml_nelements(dst);
 
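Note: for orientation, a plain-C sketch of the broadcast the add_row path implements. This mirrors the ggml_nelements(src1) == ne10 test above and is an assumption about the kernel's semantics, not code from the package:

    #include <stdint.h>

    // src1 is a single row of ne00 floats added to every row of src0
    static void add_row_ref(float * dst, const float * src0, const float * src1,
                            int64_t nrows, int64_t ne00) {
        for (int64_t r = 0; r < nrows; ++r) {
            for (int64_t c = 0; c < ne00; ++c) {
                dst[r*ne00 + c] = src0[r*ne00 + c] + src1[c];
            }
        }
    }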
@@ -476,7 +593,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             if (ggml_nelements(src1) == ne10) {
@@ -497,7 +614,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_SCALE:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const float scale = *(const float *) src1->data;
@@ -511,52 +628,60 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
-                    case GGML_OP_SILU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_silu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
-                    case GGML_OP_RELU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_relu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    case GGML_OP_UNARY:
+                        switch (ggml_get_unary_op(gf->nodes[i])) {
+                            case GGML_UNARY_OP_SILU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_RELU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_relu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_GELU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_gelu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            default:
+                                {
+                                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                    GGML_ASSERT(false);
+                                }
                         } break;
-                    case GGML_OP_GELU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_gelu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
                     case GGML_OP_SOFT_MAX:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int nth = 32;
@@ -574,10 +699,10 @@ void ggml_metal_graph_compute(
                     case GGML_OP_DIAG_MASK_INF:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
-                            const int n_past = ((int32_t *)(src1->data))[0];
+                            const int n_past = ((int32_t *)(dst->op_params))[0];
 
                             [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -637,7 +762,7 @@ void ggml_metal_graph_compute(
                                 }
                             } else {
                                 if (encoder == nil) {
-                                    encoder = [command_buffer computeCommandEncoder];
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                 }
 
                                 int nth0 = 32;
@@ -676,8 +801,8 @@ void ggml_metal_graph_compute(
                                         GGML_ASSERT(ne02 == 1);
                                         GGML_ASSERT(ne12 == 1);
 
-                                        nth0 =
-                                        nth1 =
+                                        nth0 = 2;
+                                        nth1 = 32;
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                                     } break;
                                 case GGML_TYPE_Q3_K:
@@ -685,8 +810,8 @@ void ggml_metal_graph_compute(
                                         GGML_ASSERT(ne02 == 1);
                                         GGML_ASSERT(ne12 == 1);
 
-                                        nth0 =
-                                        nth1 =
+                                        nth0 = 2;
+                                        nth1 = 32;
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                                     } break;
                                 case GGML_TYPE_Q4_K:
@@ -694,8 +819,8 @@ void ggml_metal_graph_compute(
                                         GGML_ASSERT(ne02 == 1);
                                         GGML_ASSERT(ne12 == 1);
 
-                                        nth0 =
-                                        nth1 =
+                                        nth0 = 2;
+                                        nth1 = 32;
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                     } break;
                                 case GGML_TYPE_Q5_K:
@@ -703,8 +828,8 @@ void ggml_metal_graph_compute(
                                         GGML_ASSERT(ne02 == 1);
                                         GGML_ASSERT(ne12 == 1);
 
-                                        nth0 =
-                                        nth1 =
+                                        nth0 = 2;
+                                        nth1 = 32;
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                                     } break;
                                 case GGML_TYPE_Q6_K:
@@ -712,8 +837,8 @@ void ggml_metal_graph_compute(
                                         GGML_ASSERT(ne02 == 1);
                                         GGML_ASSERT(ne12 == 1);
 
-                                        nth0 =
-                                        nth1 =
+                                        nth0 = 2;
+                                        nth1 = 32;
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                                     } break;
                                 default:
@@ -739,20 +864,22 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
                                 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
 
-                                if (src0t == GGML_TYPE_Q4_0
-
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                                    src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src0t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
                                 }
-                                else if (src0t ==
-                                    [encoder
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                else if (src0t == GGML_TYPE_Q5_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
-                                else if (src0t ==
-
-                                    src0t == GGML_TYPE_Q4_K ||
-                                    src0t == GGML_TYPE_Q5_K ||
-                                    src0t == GGML_TYPE_Q6_K) {
-                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                else if (src0t == GGML_TYPE_Q6_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
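Note: the new dispatch geometry packs several output rows into one threadgroup; the 8-rows-per-threadgroup factor comes from the (ne01 + 7) / 8 grid above. A worked example with illustrative sizes:

    // ne01 = 4096 rows, ne11 = 1 column (mat-vec), nth0 = 2, nth1 = 32
    // threadgroups = (4096 + 7) / 8 = 512, threads per group = 2 * 32 = 64
    // q3_K without GGML_QKK_64: (4096 + 3) / 4 = 1024 threadgroups (4 rows each)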
@@ -762,7 +889,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_GET_ROWS:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             switch (src0->type) {
@@ -791,12 +918,13 @@ void ggml_metal_graph_compute(
                     case GGML_OP_RMS_NORM:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
-                            const float eps = 1e-6f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
-                            const int nth =
+                            const int nth = 512;
 
                             [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -804,7 +932,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps length:sizeof( float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
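Note: the threadgroup memory change above (nth/32 floats instead of nth floats) is consistent with a SIMD-group reduction, one partial sum per SIMD group rather than one per thread. The 32-wide SIMD group is an assumption about the target GPUs:

    // nth = 512 threads per threadgroup
    // partial sums: 512 / 32 = 16
    // threadgroup memory: 16 * sizeof(float) = 64 bytes (previously nth * 4 bytes)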
@@ -813,7 +941,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_NORM:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const float eps = 1e-5f;
@@ -835,14 +963,15 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ALIBI:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
 
-                            const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
-                            const int n_head = ((int32_t *) src1->data)[1];
-                            const float max_bias = ((float *) src1->data)[2];
+                            const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                            const int n_head = ((int32_t *) dst->op_params)[1];
+                            float max_bias;
+                            memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
                             if (__builtin_popcount(n_head) != 1) {
                                 GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -877,43 +1006,51 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ROPE:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
-                            const int n_dims = ((int32_t *) src1->data)[1];
-                            const int mode   = ((int32_t *) src1->data)[2];
+                            const int n_past = ((int32_t *) dst->op_params)[0];
+                            const int n_dims = ((int32_t *) dst->op_params)[1];
+                            const int mode   = ((int32_t *) dst->op_params)[2];
 
-                            const int n_past = ((int32_t *)(src1->data))[0];
+                            float freq_base;
+                            float freq_scale;
+                            memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+                            memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                            [encoder setBytes:&ne00
-                            [encoder setBytes:&ne01
-                            [encoder setBytes:&ne02
-                            [encoder setBytes:&ne03
-                            [encoder setBytes:&nb00
-                            [encoder setBytes:&nb01
-                            [encoder setBytes:&nb02
-                            [encoder setBytes:&nb03
-                            [encoder setBytes:&ne0
-                            [encoder setBytes:&ne1
-                            [encoder setBytes:&ne2
-                            [encoder setBytes:&ne3
-                            [encoder setBytes:&nb0
-                            [encoder setBytes:&nb1
-                            [encoder setBytes:&nb2
-                            [encoder setBytes:&nb3
-                            [encoder setBytes:&n_past
-                            [encoder setBytes:&n_dims
-                            [encoder setBytes:&mode
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
+                            [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
+                            [encoder setBytes:&mode length:sizeof( int) atIndex:20];
+                            [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
+                            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
+                    case GGML_OP_DUP:
                     case GGML_OP_CPY:
+                    case GGML_OP_CONT:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int nth = 32;
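Note: freq_base and freq_scale are the new custom-RoPE parameters, passed to the kernel at buffer indices 21 and 22 above. A sketch of how they enter the rotation, mirroring ggml's CPU-side rope (an assumption for illustration, not the Metal kernel itself):

    #include <math.h>

    // angle schedule for one position p: theta_i = freq_scale * p * freq_base^(-2i/n_dims)
    static void rope_angles(float * theta_out, int n_dims, int p,
                            float freq_base, float freq_scale) {
        const float theta_scale = powf(freq_base, -2.0f/n_dims);
        float theta = freq_scale * (float) p;
        for (int i = 0; i < n_dims; i += 2) {
            theta_out[i/2] = theta;   // consumed as a cos/sin pair downstream
            theta *= theta_scale;
        }
    }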
@@ -960,8 +1097,10 @@ void ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     default:
-                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                        GGML_ASSERT(false);
+                        {
+                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            GGML_ASSERT(false);
+                        }
                 }
             }
 