llama_cpp 0.3.6 → 0.3.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/llama_cpp/src/ggml-alloc.c +8 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +1165 -721
- data/ext/llama_cpp/src/ggml-metal.m +39 -18
- data/ext/llama_cpp/src/ggml.c +396 -150
- data/ext/llama_cpp/src/ggml.h +113 -32
- data/ext/llama_cpp/src/llama-util.h +41 -1
- data/ext/llama_cpp/src/llama.cpp +214 -146
- data/ext/llama_cpp/src/llama.h +18 -1
- data/lib/llama_cpp/version.rb +2 -2
- metadata +2 -2
@@ -7,6 +7,11 @@
|
|
7
7
|
#import <Metal/Metal.h>
|
8
8
|
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
9
9
|
|
10
|
+
#undef MIN
|
11
|
+
#undef MAX
|
12
|
+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
13
|
+
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
14
|
+
|
10
15
|
#ifdef GGML_METAL_NDEBUG
|
11
16
|
#define metal_printf(...)
|
12
17
|
#else
|
@@ -15,6 +20,8 @@
|
|
15
20
|
|
16
21
|
#define UNUSED(x) (void)(x)
|
17
22
|
|
23
|
+
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
|
24
|
+
|
18
25
|
struct ggml_metal_buffer {
|
19
26
|
const char * name;
|
20
27
|
|
@@ -36,7 +43,7 @@ struct ggml_metal_context {
|
|
36
43
|
int n_buffers;
|
37
44
|
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
38
45
|
|
39
|
-
int concur_list[
|
46
|
+
int concur_list[GGML_MAX_CONCUR];
|
40
47
|
int concur_list_len;
|
41
48
|
|
42
49
|
// custom kernels
|
@@ -370,15 +377,15 @@ void ggml_metal_graph_find_concurrency(
|
|
370
377
|
struct ggml_metal_context * ctx,
|
371
378
|
struct ggml_cgraph * gf) {
|
372
379
|
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
|
373
|
-
int nodes_unused[
|
380
|
+
int nodes_unused[GGML_MAX_CONCUR];
|
374
381
|
|
375
|
-
for (int i = 0; i <
|
376
|
-
for (int i = 0; i < gf->n_nodes;
|
382
|
+
for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
|
383
|
+
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
|
377
384
|
ctx->concur_list_len = 0;
|
378
385
|
|
379
|
-
int n_left
|
380
|
-
int n_start
|
381
|
-
int level_pos = 0;
|
386
|
+
int n_left = gf->n_nodes;
|
387
|
+
int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
|
388
|
+
int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
|
382
389
|
|
383
390
|
while (n_left > 0) {
|
384
391
|
// number of nodes at a layer (that can be issued concurrently)
|
@@ -386,28 +393,40 @@ void ggml_metal_graph_find_concurrency(
|
|
386
393
|
for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
|
387
394
|
if (nodes_unused[i]) {
|
388
395
|
// if the requirements for gf->nodes[i] are satisfied
|
389
|
-
int exe_flag=1;
|
396
|
+
int exe_flag = 1;
|
397
|
+
|
390
398
|
// scan all srcs
|
391
399
|
for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
|
392
400
|
struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
|
393
401
|
if (src_cur) {
|
394
402
|
// if is leaf nodes it's satisfied.
|
395
|
-
|
403
|
+
// TODO: ggml_is_leaf()
|
404
|
+
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
|
405
|
+
continue;
|
406
|
+
}
|
396
407
|
|
397
408
|
// otherwise this src should be the output from previous nodes.
|
398
409
|
int is_found = 0;
|
410
|
+
|
399
411
|
// scan 2*search_depth back because we inserted barrier.
|
400
|
-
for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
|
401
|
-
|
412
|
+
//for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
|
413
|
+
for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
|
414
|
+
if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
|
415
|
+
is_found = 1;
|
416
|
+
break;
|
417
|
+
}
|
418
|
+
}
|
419
|
+
if (is_found == 0) {
|
420
|
+
exe_flag = 0;
|
421
|
+
break;
|
402
422
|
}
|
403
|
-
if (is_found == 0) {exe_flag = 0; break;}
|
404
423
|
}
|
405
424
|
}
|
406
425
|
if (exe_flag) {
|
407
426
|
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
|
408
427
|
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
|
409
428
|
int64_t data_start = (int64_t) gf->nodes[i]->data;
|
410
|
-
int64_t length
|
429
|
+
int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
|
411
430
|
for (int j = n_start; j < i; j++) {
|
412
431
|
if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
|
413
432
|
&& gf->nodes[j]->op != GGML_OP_VIEW \
|
@@ -416,9 +435,9 @@ void ggml_metal_graph_find_concurrency(
|
|
416
435
|
if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
|
417
436
|
((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
|
418
437
|
continue;
|
419
|
-
} else {
|
420
|
-
exe_flag = 0;
|
421
438
|
}
|
439
|
+
|
440
|
+
exe_flag = 0;
|
422
441
|
}
|
423
442
|
}
|
424
443
|
}
|
@@ -435,11 +454,13 @@ void ggml_metal_graph_find_concurrency(
|
|
435
454
|
ctx->concur_list[level_pos + concurrency] = -1;
|
436
455
|
ctx->concur_list_len++;
|
437
456
|
// jump all sorted nodes at nodes_bak
|
438
|
-
while (!nodes_unused[n_start]) {
|
457
|
+
while (!nodes_unused[n_start]) {
|
458
|
+
n_start++;
|
459
|
+
}
|
439
460
|
level_pos += concurrency + 1;
|
440
461
|
}
|
441
462
|
|
442
|
-
if (ctx->concur_list_len >
|
463
|
+
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
|
443
464
|
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
|
444
465
|
}
|
445
466
|
}
|
@@ -453,7 +474,7 @@ void ggml_metal_graph_compute(
|
|
453
474
|
// else fallback to serial dispatch
|
454
475
|
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
455
476
|
|
456
|
-
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <=
|
477
|
+
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
|
457
478
|
|
458
479
|
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
459
480
|
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|