llama_cpp 0.5.1 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -3
- data/examples/prompt_jp.txt +1 -1
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +32 -2
- data/ext/llama_cpp/src/ggml-alloc.c +6 -11
- data/ext/llama_cpp/src/ggml-cuda.cu +1108 -699
- data/ext/llama_cpp/src/ggml-metal.m +93 -24
- data/ext/llama_cpp/src/ggml-metal.metal +407 -174
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -3
- data/ext/llama_cpp/src/ggml.c +75 -43
- data/ext/llama_cpp/src/ggml.h +42 -32
- data/ext/llama_cpp/src/k_quants.c +4 -1
- data/ext/llama_cpp/src/llama.cpp +1040 -201
- data/ext/llama_cpp/src/llama.h +13 -7
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +4 -0
- metadata +2 -2
@@ -63,7 +63,10 @@ struct ggml_metal_context {
|
|
63
63
|
GGML_METAL_DECL_KERNEL(relu);
|
64
64
|
GGML_METAL_DECL_KERNEL(gelu);
|
65
65
|
GGML_METAL_DECL_KERNEL(soft_max);
|
66
|
+
GGML_METAL_DECL_KERNEL(soft_max_4);
|
66
67
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
68
|
+
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
|
69
|
+
GGML_METAL_DECL_KERNEL(get_rows_f32);
|
67
70
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
68
71
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
69
72
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
@@ -75,8 +78,10 @@ struct ggml_metal_context {
|
|
75
78
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
76
79
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
77
80
|
GGML_METAL_DECL_KERNEL(norm);
|
81
|
+
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
|
78
82
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
79
83
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
84
|
+
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
|
80
85
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
81
86
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
82
87
|
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
@@ -85,6 +90,7 @@ struct ggml_metal_context {
|
|
85
90
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
86
91
|
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
87
92
|
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
93
|
+
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
88
94
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
89
95
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
90
96
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
@@ -117,14 +123,17 @@ static NSString * const msl_library_source = @"see metal.metal";
|
|
117
123
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
118
124
|
metal_printf("%s: allocating\n", __func__);
|
119
125
|
|
120
|
-
// Show all the Metal device instances in the system
|
121
|
-
NSArray * devices = MTLCopyAllDevices();
|
122
126
|
id <MTLDevice> device;
|
123
127
|
NSString * s;
|
128
|
+
|
129
|
+
#if TARGET_OS_OSX
|
130
|
+
// Show all the Metal device instances in the system
|
131
|
+
NSArray * devices = MTLCopyAllDevices();
|
124
132
|
for (device in devices) {
|
125
133
|
s = [device name];
|
126
134
|
metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
|
127
135
|
}
|
136
|
+
#endif
|
128
137
|
|
129
138
|
// Pick and show default Metal device
|
130
139
|
device = MTLCreateSystemDefaultDevice();
|
@@ -139,14 +148,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
139
148
|
ctx->n_buffers = 0;
|
140
149
|
ctx->concur_list_len = 0;
|
141
150
|
|
142
|
-
ctx->d_queue = dispatch_queue_create("
|
151
|
+
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
143
152
|
|
144
|
-
#
|
145
|
-
//
|
153
|
+
#ifdef GGML_SWIFT
|
154
|
+
// load the default.metallib file
|
146
155
|
{
|
147
156
|
NSError * error = nil;
|
148
157
|
|
149
|
-
|
158
|
+
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
159
|
+
NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
|
160
|
+
NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
|
161
|
+
NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
|
162
|
+
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
163
|
+
|
164
|
+
// Load the metallib file into a Metal library
|
165
|
+
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
166
|
+
|
150
167
|
if (error) {
|
151
168
|
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
152
169
|
return NULL;
|
@@ -161,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
161
178
|
|
162
179
|
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
|
163
180
|
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
164
|
-
NSString * path
|
181
|
+
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
165
182
|
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
|
166
183
|
|
167
184
|
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
@@ -207,7 +224,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
207
224
|
GGML_METAL_ADD_KERNEL(relu);
|
208
225
|
GGML_METAL_ADD_KERNEL(gelu);
|
209
226
|
GGML_METAL_ADD_KERNEL(soft_max);
|
227
|
+
GGML_METAL_ADD_KERNEL(soft_max_4);
|
210
228
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
229
|
+
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
|
230
|
+
GGML_METAL_ADD_KERNEL(get_rows_f32);
|
211
231
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
212
232
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
213
233
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
@@ -219,8 +239,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
219
239
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
220
240
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
221
241
|
GGML_METAL_ADD_KERNEL(norm);
|
242
|
+
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
|
222
243
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
223
244
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
245
|
+
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
|
224
246
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
225
247
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
226
248
|
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
@@ -229,6 +251,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
229
251
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
230
252
|
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
231
253
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
254
|
+
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
232
255
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
233
256
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
234
257
|
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
@@ -247,13 +270,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
247
270
|
#undef GGML_METAL_ADD_KERNEL
|
248
271
|
}
|
249
272
|
|
250
|
-
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
251
273
|
metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
274
|
+
#if TARGET_OS_OSX
|
275
|
+
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
252
276
|
if (ctx->device.maxTransferRate != 0) {
|
253
277
|
metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
254
278
|
} else {
|
255
279
|
metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
|
256
280
|
}
|
281
|
+
#endif
|
257
282
|
|
258
283
|
return ctx;
|
259
284
|
}
|
@@ -273,7 +298,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
273
298
|
GGML_METAL_DEL_KERNEL(relu);
|
274
299
|
GGML_METAL_DEL_KERNEL(gelu);
|
275
300
|
GGML_METAL_DEL_KERNEL(soft_max);
|
301
|
+
GGML_METAL_DEL_KERNEL(soft_max_4);
|
276
302
|
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
303
|
+
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
|
304
|
+
GGML_METAL_DEL_KERNEL(get_rows_f32);
|
277
305
|
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
278
306
|
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
279
307
|
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
@@ -285,8 +313,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
285
313
|
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
286
314
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
287
315
|
GGML_METAL_DEL_KERNEL(norm);
|
316
|
+
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
|
288
317
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
289
318
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
319
|
+
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
|
290
320
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
291
321
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
292
322
|
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
@@ -295,6 +325,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
295
325
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
296
326
|
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
297
327
|
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
328
|
+
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
298
329
|
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
299
330
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
300
331
|
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
@@ -365,6 +396,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
365
396
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
366
397
|
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
367
398
|
|
399
|
+
//metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
|
368
400
|
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
369
401
|
*offs = (size_t) ioffs;
|
370
402
|
|
@@ -454,6 +486,7 @@ bool ggml_metal_add_buffer(
|
|
454
486
|
}
|
455
487
|
}
|
456
488
|
|
489
|
+
#if TARGET_OS_OSX
|
457
490
|
metal_printf(", (%8.2f / %8.2f)",
|
458
491
|
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
459
492
|
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
@@ -463,6 +496,9 @@ bool ggml_metal_add_buffer(
|
|
463
496
|
} else {
|
464
497
|
metal_printf("\n");
|
465
498
|
}
|
499
|
+
#else
|
500
|
+
metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
501
|
+
#endif
|
466
502
|
}
|
467
503
|
|
468
504
|
return true;
|
@@ -698,6 +734,7 @@ void ggml_metal_graph_compute(
|
|
698
734
|
case GGML_OP_ADD:
|
699
735
|
{
|
700
736
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
737
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
701
738
|
|
702
739
|
// utilize float4
|
703
740
|
GGML_ASSERT(ne00 % 4 == 0);
|
@@ -705,6 +742,7 @@ void ggml_metal_graph_compute(
|
|
705
742
|
|
706
743
|
if (ggml_nelements(src1) == ne10) {
|
707
744
|
// src1 is a row
|
745
|
+
GGML_ASSERT(ne11 == 1);
|
708
746
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
709
747
|
} else {
|
710
748
|
[encoder setComputePipelineState:ctx->pipeline_add];
|
@@ -721,6 +759,7 @@ void ggml_metal_graph_compute(
|
|
721
759
|
case GGML_OP_MUL:
|
722
760
|
{
|
723
761
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
762
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
724
763
|
|
725
764
|
// utilize float4
|
726
765
|
GGML_ASSERT(ne00 % 4 == 0);
|
@@ -728,6 +767,7 @@ void ggml_metal_graph_compute(
|
|
728
767
|
|
729
768
|
if (ggml_nelements(src1) == ne10) {
|
730
769
|
// src1 is a row
|
770
|
+
GGML_ASSERT(ne11 == 1);
|
731
771
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
732
772
|
} else {
|
733
773
|
[encoder setComputePipelineState:ctx->pipeline_mul];
|
@@ -743,6 +783,8 @@ void ggml_metal_graph_compute(
|
|
743
783
|
} break;
|
744
784
|
case GGML_OP_SCALE:
|
745
785
|
{
|
786
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
787
|
+
|
746
788
|
const float scale = *(const float *) src1->data;
|
747
789
|
|
748
790
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
@@ -750,7 +792,7 @@ void ggml_metal_graph_compute(
|
|
750
792
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
751
793
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
752
794
|
|
753
|
-
const int64_t n = ggml_nelements(dst);
|
795
|
+
const int64_t n = ggml_nelements(dst)/4;
|
754
796
|
|
755
797
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
756
798
|
} break;
|
@@ -762,7 +804,7 @@ void ggml_metal_graph_compute(
|
|
762
804
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
763
805
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
764
806
|
|
765
|
-
const int64_t n = ggml_nelements(dst);
|
807
|
+
const int64_t n = ggml_nelements(dst)/4;
|
766
808
|
|
767
809
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
768
810
|
} break;
|
@@ -782,7 +824,7 @@ void ggml_metal_graph_compute(
|
|
782
824
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
783
825
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
784
826
|
|
785
|
-
const int64_t n = ggml_nelements(dst);
|
827
|
+
const int64_t n = ggml_nelements(dst)/4;
|
786
828
|
|
787
829
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
788
830
|
} break;
|
@@ -796,13 +838,16 @@ void ggml_metal_graph_compute(
|
|
796
838
|
{
|
797
839
|
const int nth = 32;
|
798
840
|
|
799
|
-
|
841
|
+
if (ne00%4 == 0) {
|
842
|
+
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
843
|
+
} else {
|
844
|
+
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
845
|
+
}
|
800
846
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
801
847
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
802
848
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
803
849
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
804
850
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
805
|
-
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
806
851
|
|
807
852
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
808
853
|
} break;
|
@@ -810,14 +855,23 @@ void ggml_metal_graph_compute(
|
|
810
855
|
{
|
811
856
|
const int n_past = ((int32_t *)(dst->op_params))[0];
|
812
857
|
|
813
|
-
|
858
|
+
if (ne00%8 == 0) {
|
859
|
+
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
|
860
|
+
} else {
|
861
|
+
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
862
|
+
}
|
814
863
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
815
864
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
816
865
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
817
866
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
818
867
|
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
819
868
|
|
820
|
-
|
869
|
+
if (ne00%8 == 0) {
|
870
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
871
|
+
}
|
872
|
+
else {
|
873
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
874
|
+
}
|
821
875
|
} break;
|
822
876
|
case GGML_OP_MUL_MAT:
|
823
877
|
{
|
@@ -830,13 +884,14 @@ void ggml_metal_graph_compute(
|
|
830
884
|
|
831
885
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
832
886
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
833
|
-
if (
|
834
|
-
|
887
|
+
if (!ggml_is_transposed(src0) &&
|
888
|
+
!ggml_is_transposed(src1) &&
|
835
889
|
src1t == GGML_TYPE_F32 &&
|
836
890
|
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
837
891
|
ne00%32 == 0 &&
|
838
892
|
ne11 > 1) {
|
839
893
|
switch (src0->type) {
|
894
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
840
895
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
841
896
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
842
897
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
@@ -856,25 +911,38 @@ void ggml_metal_graph_compute(
|
|
856
911
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
857
912
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
858
913
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
859
|
-
[encoder setBytes:&
|
860
|
-
[encoder setBytes:&
|
861
|
-
[encoder setBytes:&
|
914
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
915
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
916
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
917
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
918
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
919
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
|
862
920
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
863
921
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
864
922
|
} else {
|
865
923
|
int nth0 = 32;
|
866
924
|
int nth1 = 1;
|
925
|
+
int nrows = 1;
|
867
926
|
|
868
927
|
// use custom matrix x vector kernel
|
869
928
|
switch (src0t) {
|
929
|
+
case GGML_TYPE_F32:
|
930
|
+
{
|
931
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
|
932
|
+
nrows = 4;
|
933
|
+
} break;
|
870
934
|
case GGML_TYPE_F16:
|
871
935
|
{
|
872
936
|
nth0 = 32;
|
873
937
|
nth1 = 1;
|
874
938
|
if (ne11 * ne12 < 4) {
|
875
939
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
940
|
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
941
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
942
|
+
nrows = ne11;
|
876
943
|
} else {
|
877
944
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
945
|
+
nrows = 4;
|
878
946
|
}
|
879
947
|
} break;
|
880
948
|
case GGML_TYPE_Q4_0:
|
@@ -995,7 +1063,7 @@ void ggml_metal_graph_compute(
|
|
995
1063
|
else if (src0t == GGML_TYPE_Q6_K) {
|
996
1064
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
997
1065
|
} else {
|
998
|
-
int64_t ny = (ne11 +
|
1066
|
+
int64_t ny = (ne11 + nrows - 1)/nrows;
|
999
1067
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1000
1068
|
}
|
1001
1069
|
}
|
@@ -1003,6 +1071,7 @@ void ggml_metal_graph_compute(
|
|
1003
1071
|
case GGML_OP_GET_ROWS:
|
1004
1072
|
{
|
1005
1073
|
switch (src0->type) {
|
1074
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
1006
1075
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
1007
1076
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
1008
1077
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
@@ -1018,9 +1087,9 @@ void ggml_metal_graph_compute(
|
|
1018
1087
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1019
1088
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1020
1089
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1021
|
-
[encoder setBytes:&
|
1022
|
-
[encoder setBytes:&
|
1023
|
-
[encoder setBytes:&
|
1090
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1091
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
1092
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
1024
1093
|
|
1025
1094
|
const int64_t n = ggml_nelements(src1);
|
1026
1095
|
|