llama_cpp 0.3.6 → 0.3.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +8 -0
- data/ext/llama_cpp/src/ggml-alloc.c +44 -6
- data/ext/llama_cpp/src/ggml-alloc.h +4 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +1398 -702
- data/ext/llama_cpp/src/ggml-cuda.h +19 -23
- data/ext/llama_cpp/src/ggml-metal.h +6 -3
- data/ext/llama_cpp/src/ggml-metal.m +112 -146
- data/ext/llama_cpp/src/ggml-metal.metal +471 -498
- data/ext/llama_cpp/src/ggml.c +396 -150
- data/ext/llama_cpp/src/ggml.h +113 -32
- data/ext/llama_cpp/src/llama-util.h +51 -9
- data/ext/llama_cpp/src/llama.cpp +390 -210
- data/ext/llama_cpp/src/llama.h +20 -1
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +1 -0
- metadata +2 -2
data/ext/llama_cpp/src/ggml-cuda.h:

```diff
@@ -8,29 +8,25 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES 16
 
-void ggml_init_cublas(void);
-void
-
-
-bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
-
-void
-
-
-void *
-void
-
-void
-
-void
-
-
-void
-void
-void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
-void ggml_cuda_set_scratch_size(size_t scratch_size);
-void ggml_cuda_free_scratch(void);
-bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+GGML_API void ggml_init_cublas(void);
+GGML_API void * ggml_cuda_host_malloc(size_t size);
+GGML_API void ggml_cuda_host_free(void * ptr);
+
+GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
+GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
+GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
+GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+GGML_API void ggml_cuda_set_main_device(int main_device);
+GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
+GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
+GGML_API void ggml_cuda_free_scratch(void);
+GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+
+GGML_API int ggml_cuda_get_device_count(void);
+GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 
 #ifdef __cplusplus
 }
```
data/ext/llama_cpp/src/ggml-metal.h:

```diff
@@ -63,10 +63,13 @@ void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 
 // try to find operations that can be run concurrently in the graph
 // you should run it again if the topology of your graph changes
-void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
 
-// if the graph has been optimized for concurrently dispatch
-
+// if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
+int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
+// output the concur_list for ggml_alloc
+int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
 
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
```
data/ext/llama_cpp/src/ggml-metal.m:

```diff
@@ -5,7 +5,11 @@
 #import <Foundation/Foundation.h>
 
 #import <Metal/Metal.h>
-
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
 
 #ifdef GGML_METAL_NDEBUG
 #define metal_printf(...)
```
```diff
@@ -15,6 +19,8 @@
 
 #define UNUSED(x) (void)(x)
 
+#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+
 struct ggml_metal_buffer {
     const char * name;
 
```
```diff
@@ -36,7 +42,7 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
-    int concur_list[
+    int concur_list[GGML_MAX_CONCUR];
     int concur_list_len;
 
     // custom kernels
```
```diff
@@ -72,6 +78,14 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
     GGML_METAL_DECL_KERNEL(rope);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
```
```diff
@@ -103,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
-    // determine if we can use MPS
-    if (MPSSupportsMTLDevice(ctx->device)) {
-        fprintf(stderr, "%s: using MPS\n", __func__);
-    } else {
-        fprintf(stderr, "%s: not using MPS\n", __func__);
-        GGML_ASSERT(false && "MPS not supported");
-    }
 
 #if 0
     // compile from source string and show compile log
```
```diff
@@ -119,7 +126,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
-
+            return NULL;
         }
     }
 #else
```
```diff
@@ -137,7 +144,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
-
+            return NULL;
         }
 
 #ifdef GGML_QKK_64
```
```diff
@@ -149,17 +156,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #endif
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
-
+            return NULL;
         }
     }
 #endif
 
     // load kernels
     {
+        NSError * error = nil;
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
-        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error
-        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
+        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
+        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+        if (error) { \
+            fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            return NULL; \
+        }
 
         GGML_METAL_ADD_KERNEL(add);
         GGML_METAL_ADD_KERNEL(add_row);
```
```diff
@@ -189,6 +201,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
         GGML_METAL_ADD_KERNEL(rope);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
```
```diff
@@ -221,11 +241,12 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
 
-
-
-
-
-
+int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    return ctx->concur_list_len;
+}
+
+int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
+    return ctx->concur_list;
 }
 
 // finds the Metal buffer that contains the tensor data on the GPU device
```
```diff
@@ -368,17 +389,17 @@ void ggml_metal_get_tensor(
 
 void ggml_metal_graph_find_concurrency(
         struct ggml_metal_context * ctx,
-        struct ggml_cgraph * gf) {
+        struct ggml_cgraph * gf, bool check_mem) {
     int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
-    int nodes_unused[
+    int nodes_unused[GGML_MAX_CONCUR];
 
-    for (int i = 0; i <
-    for (int i = 0; i < gf->n_nodes;
+    for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
+    for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
     ctx->concur_list_len = 0;
 
-    int n_left
-    int n_start
-    int level_pos = 0;
+    int n_left = gf->n_nodes;
+    int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
 
     while (n_left > 0) {
         // number of nodes at a layer (that can be issued concurrently)
```
```diff
@@ -386,28 +407,40 @@ void ggml_metal_graph_find_concurrency(
         for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
             if (nodes_unused[i]) {
                 // if the requirements for gf->nodes[i] are satisfied
-                int exe_flag=1;
+                int exe_flag = 1;
+
                 // scan all srcs
                 for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
                     struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
                     if (src_cur) {
                         // if is leaf nodes it's satisfied.
-
+                        // TODO: ggml_is_leaf()
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
+                            continue;
+                        }
 
                         // otherwise this src should be the output from previous nodes.
                         int is_found = 0;
+
                         // scan 2*search_depth back because we inserted barrier.
-                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
-
+                        //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                        for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
+                            if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
+                                is_found = 1;
+                                break;
+                            }
+                        }
+                        if (is_found == 0) {
+                            exe_flag = 0;
+                            break;
                         }
-                        if (is_found == 0) {exe_flag = 0; break;}
                     }
                 }
-                if (exe_flag) {
+                if (exe_flag && check_mem) {
                     // check if nodes[i]'s data will be overwritten by a node before nodes[i].
                     // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
                     int64_t data_start = (int64_t) gf->nodes[i]->data;
-                    int64_t length
+                    int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
                     for (int j = n_start; j < i; j++) {
                         if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
                             && gf->nodes[j]->op != GGML_OP_VIEW \
```
```diff
@@ -416,9 +449,9 @@ void ggml_metal_graph_find_concurrency(
                             if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
                                 ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
                                 continue;
-                            } else {
-                                exe_flag = 0;
                             }
+
+                            exe_flag = 0;
                         }
                     }
                 }
```
```diff
@@ -435,11 +468,13 @@ void ggml_metal_graph_find_concurrency(
         ctx->concur_list[level_pos + concurrency] = -1;
         ctx->concur_list_len++;
         // jump all sorted nodes at nodes_bak
-        while (!nodes_unused[n_start]) {
+        while (!nodes_unused[n_start]) {
+            n_start++;
+        }
         level_pos += concurrency + 1;
     }
 
-    if (ctx->concur_list_len >
+    if (ctx->concur_list_len > GGML_MAX_CONCUR) {
         fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
```
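The list built above stores node indices level by level, each level terminated by a `-1` entry that `ggml_metal_graph_compute` later turns into a memory barrier. A small illustrative walk over such a list follows; the function name and structure are invented for this sketch.

```c
// Illustrative only: nodes within one level may run concurrently,
// a -1 entry marks the barrier between levels.
static void walk_concur_list(const int * concur_list, int concur_list_len) {
    int level = 0;
    for (int ind = 0; ind < concur_list_len; ++ind) {
        const int node = concur_list[ind];
        if (node == -1) {
            level++; // Metal: a memoryBarrierWithScope is issued between levels
            continue;
        }
        // dispatch gf->nodes[node] as part of the current level
        (void) node;
    }
    (void) level;
}
```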
```diff
@@ -453,7 +488,7 @@ void ggml_metal_graph_compute(
     // else fallback to serial dispatch
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
 
-    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <=
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
 
     const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
     edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
```
```diff
@@ -485,7 +520,7 @@ void ggml_metal_graph_compute(
 
         id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
 
-        id<MTLComputeCommandEncoder> encoder =
+        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
         const int node_start = (cb_idx + 0) * n_nodes_per_cb;
         const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
```
```diff
@@ -494,10 +529,6 @@ void ggml_metal_graph_compute(
             const int i = has_concur ? ctx->concur_list[ind] : ind;
 
             if (i == -1) {
-                if (encoder == nil) {
-                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    continue;
-                }
                 [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
                 continue;
             }
```
```diff
@@ -571,10 +602,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_ADD:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     if (ggml_nelements(src1) == ne10) {
                         // src1 is a row
                         [encoder setComputePipelineState:ctx->pipeline_add_row];
```
```diff
@@ -592,10 +619,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_MUL:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     if (ggml_nelements(src1) == ne10) {
                         // src1 is a row
                         [encoder setComputePipelineState:ctx->pipeline_mul_row];
```
```diff
@@ -613,10 +636,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_SCALE:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const float scale = *(const float *) src1->data;
 
                     [encoder setComputePipelineState:ctx->pipeline_scale];
```
```diff
@@ -632,10 +651,6 @@ void ggml_metal_graph_compute(
                 switch (ggml_get_unary_op(gf->nodes[i])) {
                     case GGML_UNARY_OP_SILU:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             [encoder setComputePipelineState:ctx->pipeline_silu];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
```
```diff
@@ -646,10 +661,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_UNARY_OP_RELU:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             [encoder setComputePipelineState:ctx->pipeline_relu];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
```
```diff
@@ -660,10 +671,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_UNARY_OP_GELU:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                            }
-
                             [encoder setComputePipelineState:ctx->pipeline_gelu];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
```
```diff
@@ -680,10 +687,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_SOFT_MAX:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const int nth = 32;
 
                     [encoder setComputePipelineState:ctx->pipeline_soft_max];
```
```diff
@@ -698,10 +701,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_DIAG_MASK_INF:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const int n_past = ((int32_t *)(dst->op_params))[0];
 
                     [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
```
```diff
@@ -719,53 +718,43 @@ void ggml_metal_graph_compute(
 
                     GGML_ASSERT(ne00 == ne10);
                     // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+                    uint gqa = ne12/ne02;
                     GGML_ASSERT(ne03 == ne13);
 
+                    // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                    // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                     if (ggml_is_contiguous(src0) &&
                         ggml_is_contiguous(src1) &&
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                            size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
-                            size_t offs_src1_cur = offs_src1 + i02*nb12;
-                            size_t offs_dst_cur = offs_dst + i02*nb2;
-
-                            MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
-                            MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
-                            MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
-
-                            [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
-                        }
-                    } else {
-                        if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                        src1t == GGML_TYPE_F32 &&
+                        [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                        ne00%32 == 0 &&
+                        ne11 > 1) {
+                        switch (src0->type) {
+                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+                            case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+                            case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                            case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+                            case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+                            case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+                            case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+                            case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+                            default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
+                        }
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+                        [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
+                        [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
+                        [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+                        [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                     }
-
+                    else {
                         int nth0 = 32;
                         int nth1 = 1;
 
```
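The new matrix-matrix path tiles the output so that each threadgroup of 128 threads covers a 32x64 block: `(ne11+31)/32` groups along the src1 rows, `(ne01+63)/64` along the src0 rows, and `ne12` along the batch dimension. A hedged sketch of that grid arithmetic with made-up tensor sizes:

```c
#include <stdio.h>

// Example only: ne01/ne11/ne12 are invented values, used to show the
// threadgroup grid computed by the mul_mm dispatch above.
int main(void) {
    const int ne01 = 4096; // rows of src0
    const int ne11 = 512;  // rows of src1
    const int ne12 = 32;   // batch dimension

    const int tg_x = (ne11 + 31) / 32; // 16
    const int tg_y = (ne01 + 63) / 64; // 64
    const int tg_z = ne12;             // 32

    printf("threadgroups: %d x %d x %d, 128 threads each\n", tg_x, tg_y, tg_z);
    return 0;
}
```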
```diff
@@ -864,23 +853,24 @@ void ggml_metal_graph_compute(
                         [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                         [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
                         [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
+                        [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
                         if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                             src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11,
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         }
                         else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11,
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11,
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                         }
                         else if (src0t == GGML_TYPE_Q5_K) {
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11,
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         }
                         else if (src0t == GGML_TYPE_Q6_K) {
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11,
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         } else {
                             [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
```
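The `gqa` value passed at buffer index 17 is the broadcast ratio `ne12/ne02` between src1 and src0 along dimension 2, which lets grouped-query-attention graphs map several query heads onto one key/value head. A tiny hedged example with invented head counts:

```c
#include <assert.h>

// Illustrative only: 32 query heads sharing 4 KV heads gives gqa = 8,
// mirroring the `uint gqa = ne12/ne02;` computed earlier in this file.
int main(void) {
    const unsigned int ne12 = 32;          // src1 (query) heads
    const unsigned int ne02 = 4;           // src0 (key/value) heads
    const unsigned int gqa  = ne12 / ne02; // each KV head serves 8 query heads
    assert(gqa == 8);
    return 0;
}
```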
```diff
@@ -889,10 +879,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_GET_ROWS:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     switch (src0->type) {
                         case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
```
```diff
@@ -918,10 +904,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_RMS_NORM:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));
 
```
```diff
@@ -941,10 +923,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_NORM:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const float eps = 1e-5f;
 
                     const int nth = 256;
```
```diff
@@ -963,10 +941,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_ALIBI:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     GGML_ASSERT((src0t == GGML_TYPE_F32));
 
                     const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
```
```diff
@@ -1006,10 +980,6 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_ROPE:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const int n_past = ((int32_t *) dst->op_params)[0];
                     const int n_dims = ((int32_t *) dst->op_params)[1];
                     const int mode = ((int32_t *) dst->op_params)[2];
```
```diff
@@ -1050,10 +1020,6 @@ void ggml_metal_graph_compute(
             case GGML_OP_CPY:
             case GGML_OP_CONT:
                 {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
-                    }
-
                     const int nth = 32;
 
                     switch (src0t) {
```