llama_cpp 0.3.6 → 0.3.8
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +8 -0
- data/ext/llama_cpp/src/ggml-alloc.c +44 -6
- data/ext/llama_cpp/src/ggml-alloc.h +4 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +1398 -702
- data/ext/llama_cpp/src/ggml-cuda.h +19 -23
- data/ext/llama_cpp/src/ggml-metal.h +6 -3
- data/ext/llama_cpp/src/ggml-metal.m +112 -146
- data/ext/llama_cpp/src/ggml-metal.metal +471 -498
- data/ext/llama_cpp/src/ggml.c +396 -150
- data/ext/llama_cpp/src/ggml.h +113 -32
- data/ext/llama_cpp/src/llama-util.h +51 -9
- data/ext/llama_cpp/src/llama.cpp +390 -210
- data/ext/llama_cpp/src/llama.h +20 -1
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +1 -0
- metadata +2 -2
@@ -8,29 +8,25 @@ extern "C" {
|
|
8
8
|
|
9
9
|
#define GGML_CUDA_MAX_DEVICES 16
|
10
10
|
|
11
|
-
void ggml_init_cublas(void);
|
12
|
-
void
|
13
|
-
|
14
|
-
|
15
|
-
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
16
|
-
|
17
|
-
void
|
18
|
-
|
19
|
-
|
20
|
-
void *
|
21
|
-
void
|
22
|
-
|
23
|
-
void
|
24
|
-
|
25
|
-
void
|
26
|
-
|
27
|
-
|
28
|
-
void
|
29
|
-
void
|
30
|
-
void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
|
31
|
-
void ggml_cuda_set_scratch_size(size_t scratch_size);
|
32
|
-
void ggml_cuda_free_scratch(void);
|
33
|
-
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
11
|
+
GGML_API void ggml_init_cublas(void);
|
12
|
+
GGML_API void * ggml_cuda_host_malloc(size_t size);
|
13
|
+
GGML_API void ggml_cuda_host_free(void * ptr);
|
14
|
+
|
15
|
+
GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
16
|
+
GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
|
17
|
+
GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
|
18
|
+
GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
|
19
|
+
GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
|
20
|
+
GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
|
21
|
+
GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
|
22
|
+
GGML_API void ggml_cuda_set_main_device(int main_device);
|
23
|
+
GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
|
24
|
+
GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
|
25
|
+
GGML_API void ggml_cuda_free_scratch(void);
|
26
|
+
GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
27
|
+
|
28
|
+
GGML_API int ggml_cuda_get_device_count(void);
|
29
|
+
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
|
34
30
|
|
35
31
|
#ifdef __cplusplus
|
36
32
|
}
|
@@ -63,10 +63,13 @@ void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
|
|
63
63
|
|
64
64
|
// try to find operations that can be run concurrently in the graph
|
65
65
|
// you should run it again if the topology of your graph changes
|
66
|
-
void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
|
66
|
+
void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
|
67
67
|
|
68
|
-
// if the graph has been optimized for concurrently dispatch
|
69
|
-
|
68
|
+
// if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
|
69
|
+
int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
|
70
|
+
|
71
|
+
// output the concur_list for ggml_alloc
|
72
|
+
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
|
70
73
|
|
71
74
|
// same as ggml_graph_compute but uses Metal
|
72
75
|
// creates gf->n_threads command buffers in parallel
|
@@ -5,7 +5,11 @@
|
|
5
5
|
#import <Foundation/Foundation.h>
|
6
6
|
|
7
7
|
#import <Metal/Metal.h>
|
8
|
-
|
8
|
+
|
9
|
+
#undef MIN
|
10
|
+
#undef MAX
|
11
|
+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
12
|
+
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
9
13
|
|
10
14
|
#ifdef GGML_METAL_NDEBUG
|
11
15
|
#define metal_printf(...)
|
@@ -15,6 +19,8 @@
|
|
15
19
|
|
16
20
|
#define UNUSED(x) (void)(x)
|
17
21
|
|
22
|
+
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
|
23
|
+
|
18
24
|
struct ggml_metal_buffer {
|
19
25
|
const char * name;
|
20
26
|
|
@@ -36,7 +42,7 @@ struct ggml_metal_context {
|
|
36
42
|
int n_buffers;
|
37
43
|
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
38
44
|
|
39
|
-
int concur_list[
|
45
|
+
int concur_list[GGML_MAX_CONCUR];
|
40
46
|
int concur_list_len;
|
41
47
|
|
42
48
|
// custom kernels
|
@@ -72,6 +78,14 @@ struct ggml_metal_context {
|
|
72
78
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
73
79
|
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
74
80
|
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
81
|
+
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
82
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
83
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
84
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
85
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
86
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
87
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
88
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
75
89
|
GGML_METAL_DECL_KERNEL(rope);
|
76
90
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
77
91
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
@@ -103,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
103
117
|
ctx->n_buffers = 0;
|
104
118
|
ctx->concur_list_len = 0;
|
105
119
|
|
106
|
-
// determine if we can use MPS
|
107
|
-
if (MPSSupportsMTLDevice(ctx->device)) {
|
108
|
-
fprintf(stderr, "%s: using MPS\n", __func__);
|
109
|
-
} else {
|
110
|
-
fprintf(stderr, "%s: not using MPS\n", __func__);
|
111
|
-
GGML_ASSERT(false && "MPS not supported");
|
112
|
-
}
|
113
120
|
|
114
121
|
#if 0
|
115
122
|
// compile from source string and show compile log
|
@@ -119,7 +126,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
119
126
|
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
|
120
127
|
if (error) {
|
121
128
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
122
|
-
|
129
|
+
return NULL;
|
123
130
|
}
|
124
131
|
}
|
125
132
|
#else
|
@@ -137,7 +144,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
137
144
|
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
138
145
|
if (error) {
|
139
146
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
140
|
-
|
147
|
+
return NULL;
|
141
148
|
}
|
142
149
|
|
143
150
|
#ifdef GGML_QKK_64
|
@@ -149,17 +156,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
149
156
|
#endif
|
150
157
|
if (error) {
|
151
158
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
152
|
-
|
159
|
+
return NULL;
|
153
160
|
}
|
154
161
|
}
|
155
162
|
#endif
|
156
163
|
|
157
164
|
// load kernels
|
158
165
|
{
|
166
|
+
NSError * error = nil;
|
159
167
|
#define GGML_METAL_ADD_KERNEL(name) \
|
160
168
|
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
161
|
-
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error
|
162
|
-
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
|
169
|
+
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
|
170
|
+
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
|
171
|
+
if (error) { \
|
172
|
+
fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
173
|
+
return NULL; \
|
174
|
+
}
|
163
175
|
|
164
176
|
GGML_METAL_ADD_KERNEL(add);
|
165
177
|
GGML_METAL_ADD_KERNEL(add_row);
|
@@ -189,6 +201,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
189
201
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
190
202
|
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
191
203
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
204
|
+
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
205
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
206
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
|
207
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
|
208
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
|
209
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
210
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
211
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
192
212
|
GGML_METAL_ADD_KERNEL(rope);
|
193
213
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
194
214
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
@@ -221,11 +241,12 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
|
|
221
241
|
ctx->n_cb = n_cb;
|
222
242
|
}
|
223
243
|
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
244
|
+
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
|
245
|
+
return ctx->concur_list_len;
|
246
|
+
}
|
247
|
+
|
248
|
+
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
249
|
+
return ctx->concur_list;
|
229
250
|
}
|
230
251
|
|
231
252
|
// finds the Metal buffer that contains the tensor data on the GPU device
|
@@ -368,17 +389,17 @@ void ggml_metal_get_tensor(
|
|
368
389
|
|
369
390
|
void ggml_metal_graph_find_concurrency(
|
370
391
|
struct ggml_metal_context * ctx,
|
371
|
-
struct ggml_cgraph * gf) {
|
392
|
+
struct ggml_cgraph * gf, bool check_mem) {
|
372
393
|
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
|
373
|
-
int nodes_unused[
|
394
|
+
int nodes_unused[GGML_MAX_CONCUR];
|
374
395
|
|
375
|
-
for (int i = 0; i <
|
376
|
-
for (int i = 0; i < gf->n_nodes;
|
396
|
+
for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
|
397
|
+
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
|
377
398
|
ctx->concur_list_len = 0;
|
378
399
|
|
379
|
-
int n_left
|
380
|
-
int n_start
|
381
|
-
int level_pos = 0;
|
400
|
+
int n_left = gf->n_nodes;
|
401
|
+
int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
|
402
|
+
int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
|
382
403
|
|
383
404
|
while (n_left > 0) {
|
384
405
|
// number of nodes at a layer (that can be issued concurrently)
|
@@ -386,28 +407,40 @@ void ggml_metal_graph_find_concurrency(
|
|
386
407
|
for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
|
387
408
|
if (nodes_unused[i]) {
|
388
409
|
// if the requirements for gf->nodes[i] are satisfied
|
389
|
-
int exe_flag=1;
|
410
|
+
int exe_flag = 1;
|
411
|
+
|
390
412
|
// scan all srcs
|
391
413
|
for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
|
392
414
|
struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
|
393
415
|
if (src_cur) {
|
394
416
|
// if is leaf nodes it's satisfied.
|
395
|
-
|
417
|
+
// TODO: ggml_is_leaf()
|
418
|
+
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
|
419
|
+
continue;
|
420
|
+
}
|
396
421
|
|
397
422
|
// otherwise this src should be the output from previous nodes.
|
398
423
|
int is_found = 0;
|
424
|
+
|
399
425
|
// scan 2*search_depth back because we inserted barrier.
|
400
|
-
for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
|
401
|
-
|
426
|
+
//for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
|
427
|
+
for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
|
428
|
+
if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
|
429
|
+
is_found = 1;
|
430
|
+
break;
|
431
|
+
}
|
432
|
+
}
|
433
|
+
if (is_found == 0) {
|
434
|
+
exe_flag = 0;
|
435
|
+
break;
|
402
436
|
}
|
403
|
-
if (is_found == 0) {exe_flag = 0; break;}
|
404
437
|
}
|
405
438
|
}
|
406
|
-
if (exe_flag) {
|
439
|
+
if (exe_flag && check_mem) {
|
407
440
|
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
|
408
441
|
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
|
409
442
|
int64_t data_start = (int64_t) gf->nodes[i]->data;
|
410
|
-
int64_t length
|
443
|
+
int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
|
411
444
|
for (int j = n_start; j < i; j++) {
|
412
445
|
if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
|
413
446
|
&& gf->nodes[j]->op != GGML_OP_VIEW \
|
@@ -416,9 +449,9 @@ void ggml_metal_graph_find_concurrency(
|
|
416
449
|
if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
|
417
450
|
((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
|
418
451
|
continue;
|
419
|
-
} else {
|
420
|
-
exe_flag = 0;
|
421
452
|
}
|
453
|
+
|
454
|
+
exe_flag = 0;
|
422
455
|
}
|
423
456
|
}
|
424
457
|
}
|
@@ -435,11 +468,13 @@ void ggml_metal_graph_find_concurrency(
|
|
435
468
|
ctx->concur_list[level_pos + concurrency] = -1;
|
436
469
|
ctx->concur_list_len++;
|
437
470
|
// jump all sorted nodes at nodes_bak
|
438
|
-
while (!nodes_unused[n_start]) {
|
471
|
+
while (!nodes_unused[n_start]) {
|
472
|
+
n_start++;
|
473
|
+
}
|
439
474
|
level_pos += concurrency + 1;
|
440
475
|
}
|
441
476
|
|
442
|
-
if (ctx->concur_list_len >
|
477
|
+
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
|
443
478
|
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
|
444
479
|
}
|
445
480
|
}
|
@@ -453,7 +488,7 @@ void ggml_metal_graph_compute(
|
|
453
488
|
// else fallback to serial dispatch
|
454
489
|
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
455
490
|
|
456
|
-
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <=
|
491
|
+
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
|
457
492
|
|
458
493
|
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
459
494
|
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|
@@ -485,7 +520,7 @@ void ggml_metal_graph_compute(
|
|
485
520
|
|
486
521
|
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
487
522
|
|
488
|
-
id<MTLComputeCommandEncoder> encoder =
|
523
|
+
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
489
524
|
|
490
525
|
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
491
526
|
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
|
@@ -494,10 +529,6 @@ void ggml_metal_graph_compute(
|
|
494
529
|
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
495
530
|
|
496
531
|
if (i == -1) {
|
497
|
-
if (encoder == nil) {
|
498
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
499
|
-
continue;
|
500
|
-
}
|
501
532
|
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
502
533
|
continue;
|
503
534
|
}
|
@@ -571,10 +602,6 @@ void ggml_metal_graph_compute(
|
|
571
602
|
} break;
|
572
603
|
case GGML_OP_ADD:
|
573
604
|
{
|
574
|
-
if (encoder == nil) {
|
575
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
576
|
-
}
|
577
|
-
|
578
605
|
if (ggml_nelements(src1) == ne10) {
|
579
606
|
// src1 is a row
|
580
607
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
@@ -592,10 +619,6 @@ void ggml_metal_graph_compute(
|
|
592
619
|
} break;
|
593
620
|
case GGML_OP_MUL:
|
594
621
|
{
|
595
|
-
if (encoder == nil) {
|
596
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
597
|
-
}
|
598
|
-
|
599
622
|
if (ggml_nelements(src1) == ne10) {
|
600
623
|
// src1 is a row
|
601
624
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
@@ -613,10 +636,6 @@ void ggml_metal_graph_compute(
|
|
613
636
|
} break;
|
614
637
|
case GGML_OP_SCALE:
|
615
638
|
{
|
616
|
-
if (encoder == nil) {
|
617
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
618
|
-
}
|
619
|
-
|
620
639
|
const float scale = *(const float *) src1->data;
|
621
640
|
|
622
641
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
@@ -632,10 +651,6 @@ void ggml_metal_graph_compute(
|
|
632
651
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
633
652
|
case GGML_UNARY_OP_SILU:
|
634
653
|
{
|
635
|
-
if (encoder == nil) {
|
636
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
637
|
-
}
|
638
|
-
|
639
654
|
[encoder setComputePipelineState:ctx->pipeline_silu];
|
640
655
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
641
656
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -646,10 +661,6 @@ void ggml_metal_graph_compute(
|
|
646
661
|
} break;
|
647
662
|
case GGML_UNARY_OP_RELU:
|
648
663
|
{
|
649
|
-
if (encoder == nil) {
|
650
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
651
|
-
}
|
652
|
-
|
653
664
|
[encoder setComputePipelineState:ctx->pipeline_relu];
|
654
665
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
655
666
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -660,10 +671,6 @@ void ggml_metal_graph_compute(
|
|
660
671
|
} break;
|
661
672
|
case GGML_UNARY_OP_GELU:
|
662
673
|
{
|
663
|
-
if (encoder == nil) {
|
664
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
665
|
-
}
|
666
|
-
|
667
674
|
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
668
675
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
669
676
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -680,10 +687,6 @@ void ggml_metal_graph_compute(
|
|
680
687
|
} break;
|
681
688
|
case GGML_OP_SOFT_MAX:
|
682
689
|
{
|
683
|
-
if (encoder == nil) {
|
684
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
685
|
-
}
|
686
|
-
|
687
690
|
const int nth = 32;
|
688
691
|
|
689
692
|
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
@@ -698,10 +701,6 @@ void ggml_metal_graph_compute(
|
|
698
701
|
} break;
|
699
702
|
case GGML_OP_DIAG_MASK_INF:
|
700
703
|
{
|
701
|
-
if (encoder == nil) {
|
702
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
703
|
-
}
|
704
|
-
|
705
704
|
const int n_past = ((int32_t *)(dst->op_params))[0];
|
706
705
|
|
707
706
|
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
@@ -719,53 +718,43 @@ void ggml_metal_graph_compute(
|
|
719
718
|
|
720
719
|
GGML_ASSERT(ne00 == ne10);
|
721
720
|
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
721
|
+
uint gqa = ne12/ne02;
|
722
722
|
GGML_ASSERT(ne03 == ne13);
|
723
723
|
|
724
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
725
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
724
726
|
if (ggml_is_contiguous(src0) &&
|
725
727
|
ggml_is_contiguous(src1) &&
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
|
755
|
-
size_t offs_src1_cur = offs_src1 + i02*nb12;
|
756
|
-
size_t offs_dst_cur = offs_dst + i02*nb2;
|
757
|
-
|
758
|
-
MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
|
759
|
-
MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
|
760
|
-
MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
|
761
|
-
|
762
|
-
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
|
763
|
-
}
|
764
|
-
} else {
|
765
|
-
if (encoder == nil) {
|
766
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
728
|
+
src1t == GGML_TYPE_F32 &&
|
729
|
+
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
730
|
+
ne00%32 == 0 &&
|
731
|
+
ne11 > 1) {
|
732
|
+
switch (src0->type) {
|
733
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
734
|
+
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
735
|
+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
736
|
+
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
|
737
|
+
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
|
738
|
+
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
|
739
|
+
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
|
740
|
+
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
741
|
+
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
742
|
+
}
|
743
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
744
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
745
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
746
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
747
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
748
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
749
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
750
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
751
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
752
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
753
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
754
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
755
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
767
756
|
}
|
768
|
-
|
757
|
+
else {
|
769
758
|
int nth0 = 32;
|
770
759
|
int nth1 = 1;
|
771
760
|
|
@@ -864,23 +853,24 @@ void ggml_metal_graph_compute(
|
|
864
853
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
865
854
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
866
855
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
856
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
867
857
|
|
868
858
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
869
859
|
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
870
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11,
|
860
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
871
861
|
}
|
872
862
|
else if (src0t == GGML_TYPE_Q3_K) {
|
873
863
|
#ifdef GGML_QKK_64
|
874
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11,
|
864
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
875
865
|
#else
|
876
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11,
|
866
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
877
867
|
#endif
|
878
868
|
}
|
879
869
|
else if (src0t == GGML_TYPE_Q5_K) {
|
880
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11,
|
870
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
881
871
|
}
|
882
872
|
else if (src0t == GGML_TYPE_Q6_K) {
|
883
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11,
|
873
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
884
874
|
} else {
|
885
875
|
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
886
876
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -889,10 +879,6 @@ void ggml_metal_graph_compute(
|
|
889
879
|
} break;
|
890
880
|
case GGML_OP_GET_ROWS:
|
891
881
|
{
|
892
|
-
if (encoder == nil) {
|
893
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
894
|
-
}
|
895
|
-
|
896
882
|
switch (src0->type) {
|
897
883
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
898
884
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
@@ -918,10 +904,6 @@ void ggml_metal_graph_compute(
|
|
918
904
|
} break;
|
919
905
|
case GGML_OP_RMS_NORM:
|
920
906
|
{
|
921
|
-
if (encoder == nil) {
|
922
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
923
|
-
}
|
924
|
-
|
925
907
|
float eps;
|
926
908
|
memcpy(&eps, dst->op_params, sizeof(float));
|
927
909
|
|
@@ -941,10 +923,6 @@ void ggml_metal_graph_compute(
|
|
941
923
|
} break;
|
942
924
|
case GGML_OP_NORM:
|
943
925
|
{
|
944
|
-
if (encoder == nil) {
|
945
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
946
|
-
}
|
947
|
-
|
948
926
|
const float eps = 1e-5f;
|
949
927
|
|
950
928
|
const int nth = 256;
|
@@ -963,10 +941,6 @@ void ggml_metal_graph_compute(
|
|
963
941
|
} break;
|
964
942
|
case GGML_OP_ALIBI:
|
965
943
|
{
|
966
|
-
if (encoder == nil) {
|
967
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
968
|
-
}
|
969
|
-
|
970
944
|
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
971
945
|
|
972
946
|
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
@@ -1006,10 +980,6 @@ void ggml_metal_graph_compute(
|
|
1006
980
|
} break;
|
1007
981
|
case GGML_OP_ROPE:
|
1008
982
|
{
|
1009
|
-
if (encoder == nil) {
|
1010
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
1011
|
-
}
|
1012
|
-
|
1013
983
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
1014
984
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
1015
985
|
const int mode = ((int32_t *) dst->op_params)[2];
|
@@ -1050,10 +1020,6 @@ void ggml_metal_graph_compute(
|
|
1050
1020
|
case GGML_OP_CPY:
|
1051
1021
|
case GGML_OP_CONT:
|
1052
1022
|
{
|
1053
|
-
if (encoder == nil) {
|
1054
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
1055
|
-
}
|
1056
|
-
|
1057
1023
|
const int nth = 32;
|
1058
1024
|
|
1059
1025
|
switch (src0t) {
|