llama_cpp 0.3.4 → 0.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/README.md +18 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +315 -8
- data/ext/llama_cpp/src/ggml-alloc.c +541 -0
- data/ext/llama_cpp/src/ggml-alloc.h +22 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +2271 -414
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +7 -0
- data/ext/llama_cpp/src/ggml-metal.m +218 -87
- data/ext/llama_cpp/src/ggml-metal.metal +72 -55
- data/ext/llama_cpp/src/ggml.c +754 -996
- data/ext/llama_cpp/src/ggml.h +94 -18
- data/ext/llama_cpp/src/k_quants.c +350 -24
- data/ext/llama_cpp/src/llama.cpp +713 -179
- data/ext/llama_cpp/src/llama.h +61 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +26 -0
- metadata +4 -2
data/ext/llama_cpp/src/ggml-cuda.h

@@ -27,6 +27,7 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
 void ggml_cuda_set_main_device(int main_device);
+void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 void ggml_cuda_set_scratch_size(size_t scratch_size);
 void ggml_cuda_free_scratch(void);
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
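The only ggml-cuda.h change is the new `ggml_cuda_set_mul_mat_q` toggle for the quantized matrix-multiplication kernels. A minimal sketch of how it could be called next to the existing device and scratch setters declared in this header (the device index and scratch size below are illustrative, not values from the gem):

```c
#include <stdbool.h>
#include <stddef.h>
#include "ggml-cuda.h"

// Hypothetical setup helper, not part of the gem.
static void configure_cuda_backend(void) {
    ggml_cuda_set_main_device(0);                            // use GPU 0
    ggml_cuda_set_mul_mat_q(true);                           // enable the mul_mat_q kernels
    ggml_cuda_set_scratch_size((size_t) 512 * 1024 * 1024);  // example: 512 MiB scratch
}
```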
data/ext/llama_cpp/src/ggml-metal.h

@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 // get data from the device into host memory
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// if the graph has been optimized for concurrently dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
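ggml-metal.h now exposes a concurrency-analysis pass alongside the existing compute entry point. A minimal sketch of how the new calls might fit together, assuming the Metal context and the graph have already been set up and the model buffers mapped elsewhere:

```c
#include "ggml.h"
#include "ggml-metal.h"

// Illustrative helper, not part of the gem.
static void metal_compute_with_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf) {
    // analyse the graph once; re-run this whenever the graph topology changes
    ggml_metal_graph_find_concurrency(ctx, gf);

    if (ggml_metal_if_optimized(ctx)) {
        // a concurrency list was built, so nodes will be dispatched concurrently
        // with explicit memory barriers between dependency levels
    }

    ggml_metal_graph_compute(ctx, gf);
}
```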
data/ext/llama_cpp/src/ggml-metal.m

@@ -36,12 +36,16 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+    int concur_list[GGML_MAX_NODES];
+    int concur_list_len;
+
     // custom kernels
 #define GGML_METAL_DECL_KERNEL(name) \
     id<MTLFunction> function_##name; \
     id<MTLComputePipelineState> pipeline_##name
 
     GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
@@ -97,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
+    ctx->concur_list_len = 0;
 
     // determine if we can use MPS
     if (MPSSupportsMTLDevice(ctx->device)) {
@@ -157,6 +162,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
         GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
@@ -215,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
 
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -353,11 +366,98 @@ void ggml_metal_get_tensor(
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
 }
 
+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+    for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+    ctx->concur_list_len = 0;
+
+    int n_left = gf->n_nodes;
+    int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // if the requirements for gf->nodes[i] are satisfied
+                int exe_flag=1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // if is leaf nodes it's satisfied.
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+                        // otherwise this src should be the output from previous nodes.
+                        int is_found = 0;
+                        // scan 2*search_depth back because we inserted barrier.
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+                        }
+                        if (is_found == 0) {exe_flag = 0; break;}
+                    }
+                }
+                if (exe_flag) {
+                    // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+                    // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+                                            && gf->nodes[j]->op != GGML_OP_VIEW \
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // adding a barrier different layer
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // jump all sorted nodes at nodes_bak
+        while (!nodes_unused[n_start]) {n_start++;}
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    // if there is ctx->concur_list, dispatch concurrently
+    // else fallback to serial dispatch
+    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+    const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
+    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
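The pass above fills `ctx->concur_list` with node indices grouped into dependency levels and inserts `-1` entries as barriers between levels; the compute loop (continued in the hunks below) walks that list, encoding each level with a concurrent dispatch type and emitting a buffer-scope memory barrier at every `-1`. An illustrative sketch, not from the gem, of reading such a list:

```c
#include <stdio.h>

// Prints the concurrency levels encoded in a list produced by
// ggml_metal_graph_find_concurrency: entries are node indices,
// and -1 marks a barrier between levels.
static void print_concurrency_levels(const int * concur_list, int concur_list_len) {
    int level = 0;
    for (int i = 0; i < concur_list_len; ++i) {
        if (concur_list[i] == -1) {
            level++; // barrier: the next group must wait for the previous one
        } else {
            printf("level %d: node %d\n", level, concur_list[i]);
        }
    }
}
```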
@@ -376,7 +476,7 @@ void ggml_metal_graph_compute(
     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
         dispatch_async(queue, ^{
             size_t offs_src0 = 0;
@@ -387,10 +487,21 @@ void ggml_metal_graph_compute(
 
             id<MTLComputeCommandEncoder> encoder = nil;
 
-            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int ind = node_start; ind < node_end; ++ind) {
+                const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+                if (i == -1) {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                        continue;
+                    }
+                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    continue;
+                }
 
-            for (int i = node_start; i < node_end; ++i) {
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -461,13 +572,19 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ADD:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
-                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            if (ggml_nelements(src1) == ne10) {
+                                // src1 is a row
+                                [encoder setComputePipelineState:ctx->pipeline_add_row];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_add];
+                            }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
                             const int64_t n = ggml_nelements(dst);
 
@@ -476,7 +593,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             if (ggml_nelements(src1) == ne10) {
@@ -497,7 +614,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_SCALE:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const float scale = *(const float *) src1->data;
@@ -511,52 +628,60 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
-                    case GGML_OP_SILU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_silu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
-                    case GGML_OP_RELU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_relu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    case GGML_OP_UNARY:
+                        switch (ggml_get_unary_op(gf->nodes[i])) {
+                            case GGML_UNARY_OP_SILU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_RELU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_relu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_GELU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_gelu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            default:
+                                {
+                                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                    GGML_ASSERT(false);
+                                }
                         } break;
-                    case GGML_OP_GELU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_gelu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
                     case GGML_OP_SOFT_MAX:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int nth = 32;
@@ -574,10 +699,10 @@ void ggml_metal_graph_compute(
                     case GGML_OP_DIAG_MASK_INF:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
-                            const int n_past = ((int32_t *)(src1->data))[0];
+                            const int n_past = ((int32_t *)(dst->op_params))[0];
 
                             [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -593,7 +718,8 @@ void ggml_metal_graph_compute(
                             // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
 
                             GGML_ASSERT(ne00 == ne10);
-                            GGML_ASSERT(ne02 == ne12);
+                            // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+                            GGML_ASSERT(ne03 == ne13);
 
                             if (ggml_is_contiguous(src0) &&
                                 ggml_is_contiguous(src1) &&
@@ -621,11 +747,11 @@ void ggml_metal_graph_compute(
                                         initWithDevice:ctx->device transposeLeft:false transposeRight:true
                                         resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
 
-                                // we need to do ne02 multiplications
+                                // we need to do ne12 multiplications
                                 // TODO: is there a way to do this in parallel - currently very slow ..
                                 // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
-                                for (int64_t i02 = 0; i02 < ne02; ++i02) {
-                                    size_t offs_src0_cur = offs_src0 + i02*nb02;
+                                for (int64_t i02 = 0; i02 < ne12; ++i02) {
+                                    size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
                                     size_t offs_src1_cur = offs_src1 + i02*nb12;
                                     size_t offs_dst_cur = offs_dst + i02*nb2;
 
@@ -637,7 +763,7 @@ void ggml_metal_graph_compute(
                                 }
                             } else {
                                 if (encoder == nil) {
-                                    encoder = [command_buffer computeCommandEncoder];
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                 }
 
                                 int nth0 = 32;
@@ -647,8 +773,6 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F16:
                                         {
-                                            GGML_ASSERT(ne02 == ne12);
-
                                             nth0 = 64;
                                             nth1 = 1;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -728,16 +852,18 @@ void ggml_metal_graph_compute(
                                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                                 [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                                 [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
-                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
-                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
-                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
-                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
-                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
-                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
-                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
-                                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
-                                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+                                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+                                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
+                                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
+                                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
@@ -764,7 +890,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_GET_ROWS:
                        {
                            if (encoder == nil) {
-                               encoder = [command_buffer computeCommandEncoder];
+                               encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                            }
 
                            switch (src0->type) {
@@ -793,10 +919,11 @@ void ggml_metal_graph_compute(
                     case GGML_OP_RMS_NORM:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
-                            const float eps = 1e-6f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
                             const int nth = 512;
 
@@ -815,7 +942,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_NORM:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const float eps = 1e-5f;
@@ -837,14 +964,15 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ALIBI:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
 
-                            const int   n_past   = ((int32_t *) src1->data)[0];
-                            const int   n_head   = ((int32_t *) src1->data)[1];
-                            const float max_bias = ((float *)   src1->data)[2];
+                            const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                            const int n_head = ((int32_t *) dst->op_params)[1];
+                            float max_bias;
+                            memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
                             if (__builtin_popcount(n_head) != 1) {
                                 GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -879,18 +1007,17 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ROPE:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
-                            const int n_dims = ((int32_t *) src1->data)[1];
-                            const int mode   = ((int32_t *) src1->data)[2];
-
-                            const int n_past = ((int32_t *)(src1->data))[0];
+                            const int n_past = ((int32_t *) dst->op_params)[0];
+                            const int n_dims = ((int32_t *) dst->op_params)[1];
+                            const int mode = ((int32_t *) dst->op_params)[2];
 
                             float freq_base;
                             float freq_scale;
-                            memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
-                            memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+                            memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
+                            memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -919,10 +1046,12 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
+                    case GGML_OP_DUP:
                     case GGML_OP_CPY:
+                    case GGML_OP_CONT:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int nth = 32;
@@ -969,8 +1098,10 @@ void ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     default:
-                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                        GGML_ASSERT(false);
+                        {
+                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            GGML_ASSERT(false);
+                        }
                 }
             }
 