llama_cpp 0.5.1 → 0.5.3
Diff of files changed in the llama_cpp Ruby gem between versions 0.5.1 and 0.5.3 (primarily updates to the vendored llama.cpp / ggml sources).
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -3
- data/examples/prompt_jp.txt +1 -1
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +32 -2
- data/ext/llama_cpp/src/ggml-alloc.c +6 -11
- data/ext/llama_cpp/src/ggml-cuda.cu +1108 -699
- data/ext/llama_cpp/src/ggml-metal.m +93 -24
- data/ext/llama_cpp/src/ggml-metal.metal +407 -174
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -3
- data/ext/llama_cpp/src/ggml.c +75 -43
- data/ext/llama_cpp/src/ggml.h +42 -32
- data/ext/llama_cpp/src/k_quants.c +4 -1
- data/ext/llama_cpp/src/llama.cpp +1040 -201
- data/ext/llama_cpp/src/llama.h +13 -7
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +4 -0
- metadata +2 -2
@@ -63,7 +63,10 @@ struct ggml_metal_context {
|
|
63
63
|
GGML_METAL_DECL_KERNEL(relu);
|
64
64
|
GGML_METAL_DECL_KERNEL(gelu);
|
65
65
|
GGML_METAL_DECL_KERNEL(soft_max);
|
66
|
+
GGML_METAL_DECL_KERNEL(soft_max_4);
|
66
67
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
68
|
+
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
|
69
|
+
GGML_METAL_DECL_KERNEL(get_rows_f32);
|
67
70
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
68
71
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
69
72
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
@@ -75,8 +78,10 @@ struct ggml_metal_context {
|
|
75
78
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
76
79
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
77
80
|
GGML_METAL_DECL_KERNEL(norm);
|
81
|
+
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
|
78
82
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
79
83
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
84
|
+
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
|
80
85
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
81
86
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
82
87
|
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
@@ -85,6 +90,7 @@ struct ggml_metal_context {
|
|
85
90
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
86
91
|
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
87
92
|
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
93
|
+
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
88
94
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
89
95
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
90
96
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
@@ -117,14 +123,17 @@ static NSString * const msl_library_source = @"see metal.metal";
|
|
117
123
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
118
124
|
metal_printf("%s: allocating\n", __func__);
|
119
125
|
|
120
|
-
// Show all the Metal device instances in the system
|
121
|
-
NSArray * devices = MTLCopyAllDevices();
|
122
126
|
id <MTLDevice> device;
|
123
127
|
NSString * s;
|
128
|
+
|
129
|
+
#if TARGET_OS_OSX
|
130
|
+
// Show all the Metal device instances in the system
|
131
|
+
NSArray * devices = MTLCopyAllDevices();
|
124
132
|
for (device in devices) {
|
125
133
|
s = [device name];
|
126
134
|
metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
|
127
135
|
}
|
136
|
+
#endif
|
128
137
|
|
129
138
|
// Pick and show default Metal device
|
130
139
|
device = MTLCreateSystemDefaultDevice();
|
@@ -139,14 +148,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
139
148
|
ctx->n_buffers = 0;
|
140
149
|
ctx->concur_list_len = 0;
|
141
150
|
|
142
|
-
ctx->d_queue = dispatch_queue_create("
|
151
|
+
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
143
152
|
|
144
|
-
#
|
145
|
-
//
|
153
|
+
#ifdef GGML_SWIFT
|
154
|
+
// load the default.metallib file
|
146
155
|
{
|
147
156
|
NSError * error = nil;
|
148
157
|
|
149
|
-
|
158
|
+
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
159
|
+
NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
|
160
|
+
NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
|
161
|
+
NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
|
162
|
+
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
163
|
+
|
164
|
+
// Load the metallib file into a Metal library
|
165
|
+
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
166
|
+
|
150
167
|
if (error) {
|
151
168
|
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
152
169
|
return NULL;
|
@@ -161,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
161
178
|
|
162
179
|
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
|
163
180
|
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
164
|
-
NSString * path
|
181
|
+
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
165
182
|
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
|
166
183
|
|
167
184
|
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
@@ -207,7 +224,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
207
224
|
GGML_METAL_ADD_KERNEL(relu);
|
208
225
|
GGML_METAL_ADD_KERNEL(gelu);
|
209
226
|
GGML_METAL_ADD_KERNEL(soft_max);
|
227
|
+
GGML_METAL_ADD_KERNEL(soft_max_4);
|
210
228
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
229
|
+
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
|
230
|
+
GGML_METAL_ADD_KERNEL(get_rows_f32);
|
211
231
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
212
232
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
213
233
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
@@ -219,8 +239,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
219
239
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
220
240
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
221
241
|
GGML_METAL_ADD_KERNEL(norm);
|
242
|
+
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
|
222
243
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
223
244
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
245
|
+
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
|
224
246
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
225
247
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
226
248
|
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
@@ -229,6 +251,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
229
251
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
230
252
|
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
231
253
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
254
|
+
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
232
255
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
233
256
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
234
257
|
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
@@ -247,13 +270,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
247
270
|
#undef GGML_METAL_ADD_KERNEL
|
248
271
|
}
|
249
272
|
|
250
|
-
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
251
273
|
metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
274
|
+
#if TARGET_OS_OSX
|
275
|
+
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
252
276
|
if (ctx->device.maxTransferRate != 0) {
|
253
277
|
metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
254
278
|
} else {
|
255
279
|
metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
|
256
280
|
}
|
281
|
+
#endif
|
257
282
|
|
258
283
|
return ctx;
|
259
284
|
}
|
@@ -273,7 +298,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
273
298
|
GGML_METAL_DEL_KERNEL(relu);
|
274
299
|
GGML_METAL_DEL_KERNEL(gelu);
|
275
300
|
GGML_METAL_DEL_KERNEL(soft_max);
|
301
|
+
GGML_METAL_DEL_KERNEL(soft_max_4);
|
276
302
|
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
303
|
+
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
|
304
|
+
GGML_METAL_DEL_KERNEL(get_rows_f32);
|
277
305
|
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
278
306
|
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
279
307
|
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
@@ -285,8 +313,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
285
313
|
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
286
314
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
287
315
|
GGML_METAL_DEL_KERNEL(norm);
|
316
|
+
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
|
288
317
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
289
318
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
319
|
+
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
|
290
320
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
291
321
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
292
322
|
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
@@ -295,6 +325,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
295
325
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
296
326
|
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
297
327
|
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
328
|
+
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
298
329
|
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
299
330
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
300
331
|
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
@@ -365,6 +396,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
365
396
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
366
397
|
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
367
398
|
|
399
|
+
//metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
|
368
400
|
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
369
401
|
*offs = (size_t) ioffs;
|
370
402
|
|
@@ -454,6 +486,7 @@ bool ggml_metal_add_buffer(
|
|
454
486
|
}
|
455
487
|
}
|
456
488
|
|
489
|
+
#if TARGET_OS_OSX
|
457
490
|
metal_printf(", (%8.2f / %8.2f)",
|
458
491
|
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
459
492
|
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
@@ -463,6 +496,9 @@ bool ggml_metal_add_buffer(
|
|
463
496
|
} else {
|
464
497
|
metal_printf("\n");
|
465
498
|
}
|
499
|
+
#else
|
500
|
+
metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
501
|
+
#endif
|
466
502
|
}
|
467
503
|
|
468
504
|
return true;
|
@@ -698,6 +734,7 @@ void ggml_metal_graph_compute(
|
|
698
734
|
case GGML_OP_ADD:
|
699
735
|
{
|
700
736
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
737
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
701
738
|
|
702
739
|
// utilize float4
|
703
740
|
GGML_ASSERT(ne00 % 4 == 0);
|
@@ -705,6 +742,7 @@ void ggml_metal_graph_compute(
|
|
705
742
|
|
706
743
|
if (ggml_nelements(src1) == ne10) {
|
707
744
|
// src1 is a row
|
745
|
+
GGML_ASSERT(ne11 == 1);
|
708
746
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
709
747
|
} else {
|
710
748
|
[encoder setComputePipelineState:ctx->pipeline_add];
|
@@ -721,6 +759,7 @@ void ggml_metal_graph_compute(
|
|
721
759
|
case GGML_OP_MUL:
|
722
760
|
{
|
723
761
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
762
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
724
763
|
|
725
764
|
// utilize float4
|
726
765
|
GGML_ASSERT(ne00 % 4 == 0);
|
@@ -728,6 +767,7 @@ void ggml_metal_graph_compute(
|
|
728
767
|
|
729
768
|
if (ggml_nelements(src1) == ne10) {
|
730
769
|
// src1 is a row
|
770
|
+
GGML_ASSERT(ne11 == 1);
|
731
771
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
732
772
|
} else {
|
733
773
|
[encoder setComputePipelineState:ctx->pipeline_mul];
|
@@ -743,6 +783,8 @@ void ggml_metal_graph_compute(
|
|
743
783
|
} break;
|
744
784
|
case GGML_OP_SCALE:
|
745
785
|
{
|
786
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
787
|
+
|
746
788
|
const float scale = *(const float *) src1->data;
|
747
789
|
|
748
790
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
@@ -750,7 +792,7 @@ void ggml_metal_graph_compute(
|
|
750
792
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
751
793
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
752
794
|
|
753
|
-
const int64_t n = ggml_nelements(dst);
|
795
|
+
const int64_t n = ggml_nelements(dst)/4;
|
754
796
|
|
755
797
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
756
798
|
} break;
|
@@ -762,7 +804,7 @@ void ggml_metal_graph_compute(
|
|
762
804
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
763
805
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
764
806
|
|
765
|
-
const int64_t n = ggml_nelements(dst);
|
807
|
+
const int64_t n = ggml_nelements(dst)/4;
|
766
808
|
|
767
809
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
768
810
|
} break;
|
@@ -782,7 +824,7 @@ void ggml_metal_graph_compute(
|
|
782
824
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
783
825
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
784
826
|
|
785
|
-
const int64_t n = ggml_nelements(dst);
|
827
|
+
const int64_t n = ggml_nelements(dst)/4;
|
786
828
|
|
787
829
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
788
830
|
} break;
|
@@ -796,13 +838,16 @@ void ggml_metal_graph_compute(
|
|
796
838
|
{
|
797
839
|
const int nth = 32;
|
798
840
|
|
799
|
-
|
841
|
+
if (ne00%4 == 0) {
|
842
|
+
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
843
|
+
} else {
|
844
|
+
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
845
|
+
}
|
800
846
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
801
847
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
802
848
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
803
849
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
804
850
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
805
|
-
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
806
851
|
|
807
852
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
808
853
|
} break;
|
@@ -810,14 +855,23 @@ void ggml_metal_graph_compute(
|
|
810
855
|
{
|
811
856
|
const int n_past = ((int32_t *)(dst->op_params))[0];
|
812
857
|
|
813
|
-
|
858
|
+
if (ne00%8 == 0) {
|
859
|
+
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
|
860
|
+
} else {
|
861
|
+
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
862
|
+
}
|
814
863
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
815
864
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
816
865
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
817
866
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
818
867
|
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
819
868
|
|
820
|
-
|
869
|
+
if (ne00%8 == 0) {
|
870
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
871
|
+
}
|
872
|
+
else {
|
873
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
874
|
+
}
|
821
875
|
} break;
|
822
876
|
case GGML_OP_MUL_MAT:
|
823
877
|
{
|
@@ -830,13 +884,14 @@ void ggml_metal_graph_compute(
|
|
830
884
|
|
831
885
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
832
886
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
833
|
-
if (ggml_is_contiguous(src0) &&
|
834
|
-
|
887
|
+
if (!ggml_is_transposed(src0) &&
|
888
|
+
!ggml_is_transposed(src1) &&
|
835
889
|
src1t == GGML_TYPE_F32 &&
|
836
890
|
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
837
891
|
ne00%32 == 0 &&
|
838
892
|
ne11 > 1) {
|
839
893
|
switch (src0->type) {
|
894
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
840
895
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
841
896
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
842
897
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
@@ -856,25 +911,38 @@ void ggml_metal_graph_compute(
|
|
856
911
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
857
912
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
858
913
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
859
|
-
[encoder setBytes:&ne0  length:sizeof(ne0) atIndex:8];
|
860
|
-
[encoder setBytes:&ne1  length:sizeof(ne1) atIndex:9];
|
861
|
-
[encoder setBytes:&gqa  length:sizeof(gqa) atIndex:10];
|
914
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
915
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
916
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
917
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
918
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
919
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
|
862
920
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
863
921
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
864
922
|
} else {
|
865
923
|
int nth0 = 32;
|
866
924
|
int nth1 = 1;
|
925
|
+
int nrows = 1;
|
867
926
|
|
868
927
|
// use custom matrix x vector kernel
|
869
928
|
switch (src0t) {
|
929
|
+
case GGML_TYPE_F32:
|
930
|
+
{
|
931
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
|
932
|
+
nrows = 4;
|
933
|
+
} break;
|
870
934
|
case GGML_TYPE_F16:
|
871
935
|
{
|
872
936
|
nth0 = 32;
|
873
937
|
nth1 = 1;
|
874
938
|
if (ne11 * ne12 < 4) {
|
875
939
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
940
|
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
941
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
942
|
+
nrows = ne11;
|
876
943
|
} else {
|
877
944
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
945
|
+
nrows = 4;
|
878
946
|
}
|
879
947
|
} break;
|
880
948
|
case GGML_TYPE_Q4_0:
|
@@ -995,7 +1063,7 @@ void ggml_metal_graph_compute(
|
|
995
1063
|
else if (src0t == GGML_TYPE_Q6_K) {
|
996
1064
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
997
1065
|
} else {
|
998
|
-
int64_t ny = (ne11 + 3)/4;
|
1066
|
+
int64_t ny = (ne11 + nrows - 1)/nrows;
|
999
1067
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1000
1068
|
}
|
1001
1069
|
}
|
@@ -1003,6 +1071,7 @@ void ggml_metal_graph_compute(
|
|
1003
1071
|
case GGML_OP_GET_ROWS:
|
1004
1072
|
{
|
1005
1073
|
switch (src0->type) {
|
1074
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
1006
1075
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
1007
1076
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
1008
1077
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
@@ -1018,9 +1087,9 @@ void ggml_metal_graph_compute(
|
|
1018
1087
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1019
1088
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1020
1089
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1021
|
-
[encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
|
1022
|
-
[encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
|
1023
|
-
[encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
|
1090
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1091
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
1092
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
1024
1093
|
|
1025
1094
|
const int64_t n = ggml_nelements(src1);
|
1026
1095
|
|