llama_cpp 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +17 -0
- data/README.md +1 -1
- data/examples/chat.rb +2 -4
- data/ext/llama_cpp/extconf.rb +3 -3
- data/ext/llama_cpp/llama_cpp.cpp +118 -117
- data/ext/llama_cpp/src/ggml-alloc.c +97 -53
- data/ext/llama_cpp/src/ggml-alloc.h +4 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +1010 -497
- data/ext/llama_cpp/src/ggml-cuda.h +32 -23
- data/ext/llama_cpp/src/ggml-metal.h +9 -3
- data/ext/llama_cpp/src/ggml-metal.m +142 -161
- data/ext/llama_cpp/src/ggml-metal.metal +577 -500
- data/ext/llama_cpp/src/ggml.c +2064 -233
- data/ext/llama_cpp/src/ggml.h +238 -13
- data/ext/llama_cpp/src/k_quants.c +110 -54
- data/ext/llama_cpp/src/llama-util.h +10 -8
- data/ext/llama_cpp/src/llama.cpp +4544 -2890
- data/ext/llama_cpp/src/llama.h +133 -123
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +8 -8
- metadata +2 -2
@@ -5,7 +5,6 @@
|
|
5
5
|
#import <Foundation/Foundation.h>
|
6
6
|
|
7
7
|
#import <Metal/Metal.h>
|
8
|
-
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
9
8
|
|
10
9
|
#undef MIN
|
11
10
|
#undef MAX
|
@@ -64,6 +63,7 @@ struct ggml_metal_context {
|
|
64
63
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
65
64
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
66
65
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
66
|
+
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
|
67
67
|
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
68
68
|
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
69
69
|
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
@@ -74,11 +74,21 @@ struct ggml_metal_context {
|
|
74
74
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
75
75
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
76
76
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
77
|
+
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
77
78
|
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
|
78
79
|
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
|
79
80
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
80
81
|
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
81
82
|
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
83
|
+
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
84
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
85
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
86
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
|
87
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
88
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
89
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
90
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
91
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
82
92
|
GGML_METAL_DECL_KERNEL(rope);
|
83
93
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
84
94
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
@@ -110,13 +120,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
110
120
|
ctx->n_buffers = 0;
|
111
121
|
ctx->concur_list_len = 0;
|
112
122
|
|
113
|
-
// determine if we can use MPS
|
114
|
-
if (MPSSupportsMTLDevice(ctx->device)) {
|
115
|
-
fprintf(stderr, "%s: using MPS\n", __func__);
|
116
|
-
} else {
|
117
|
-
fprintf(stderr, "%s: not using MPS\n", __func__);
|
118
|
-
GGML_ASSERT(false && "MPS not supported");
|
119
|
-
}
|
120
123
|
|
121
124
|
#if 0
|
122
125
|
// compile from source string and show compile log
|
@@ -126,7 +129,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
126
129
|
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
|
127
130
|
if (error) {
|
128
131
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
129
|
-
|
132
|
+
return NULL;
|
130
133
|
}
|
131
134
|
}
|
132
135
|
#else
|
@@ -144,7 +147,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
144
147
|
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
145
148
|
if (error) {
|
146
149
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
147
|
-
|
150
|
+
return NULL;
|
148
151
|
}
|
149
152
|
|
150
153
|
#ifdef GGML_QKK_64
|
@@ -156,17 +159,24 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
156
159
|
#endif
|
157
160
|
if (error) {
|
158
161
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
159
|
-
|
162
|
+
return NULL;
|
160
163
|
}
|
161
164
|
}
|
162
165
|
#endif
|
163
166
|
|
164
167
|
// load kernels
|
165
168
|
{
|
169
|
+
NSError * error = nil;
|
166
170
|
#define GGML_METAL_ADD_KERNEL(name) \
|
167
171
|
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
168
|
-
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error
|
169
|
-
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name
|
172
|
+
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
|
173
|
+
fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
|
174
|
+
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
|
175
|
+
(int) ctx->pipeline_##name.threadExecutionWidth); \
|
176
|
+
if (error) { \
|
177
|
+
fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
178
|
+
return NULL; \
|
179
|
+
}
|
170
180
|
|
171
181
|
GGML_METAL_ADD_KERNEL(add);
|
172
182
|
GGML_METAL_ADD_KERNEL(add_row);
|
@@ -181,6 +191,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
181
191
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
182
192
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
183
193
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
194
|
+
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
|
184
195
|
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
|
185
196
|
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
|
186
197
|
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
@@ -191,11 +202,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
191
202
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
192
203
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
193
204
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
205
|
+
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
194
206
|
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
|
195
207
|
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
|
196
208
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
197
209
|
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
198
210
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
211
|
+
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
212
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
213
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
214
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
|
215
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
|
216
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
|
217
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
218
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
219
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
199
220
|
GGML_METAL_ADD_KERNEL(rope);
|
200
221
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
201
222
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
@@ -205,12 +226,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
205
226
|
#undef GGML_METAL_ADD_KERNEL
|
206
227
|
}
|
207
228
|
|
208
|
-
fprintf(stderr, "%s: recommendedMaxWorkingSetSize
|
209
|
-
fprintf(stderr, "%s: hasUnifiedMemory
|
229
|
+
fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
230
|
+
fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
210
231
|
if (ctx->device.maxTransferRate != 0) {
|
211
|
-
fprintf(stderr, "%s: maxTransferRate
|
232
|
+
fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
212
233
|
} else {
|
213
|
-
fprintf(stderr, "%s: maxTransferRate
|
234
|
+
fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
|
214
235
|
}
|
215
236
|
|
216
237
|
return ctx;
|
@@ -224,15 +245,31 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
224
245
|
free(ctx);
|
225
246
|
}
|
226
247
|
|
248
|
+
void * ggml_metal_host_malloc(size_t n) {
|
249
|
+
void * data = NULL;
|
250
|
+
const int result = posix_memalign((void **) &data, getpagesize(), n);
|
251
|
+
if (result != 0) {
|
252
|
+
fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
|
253
|
+
return NULL;
|
254
|
+
}
|
255
|
+
|
256
|
+
return data;
|
257
|
+
}
|
258
|
+
|
259
|
+
void ggml_metal_host_free(void * data) {
|
260
|
+
free(data);
|
261
|
+
}
|
262
|
+
|
227
263
|
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
|
228
264
|
ctx->n_cb = n_cb;
|
229
265
|
}
|
230
266
|
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
267
|
+
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
|
268
|
+
return ctx->concur_list_len;
|
269
|
+
}
|
270
|
+
|
271
|
+
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
272
|
+
return ctx->concur_list;
|
236
273
|
}
|
237
274
|
|
238
275
|
// finds the Metal buffer that contains the tensor data on the GPU device
|
@@ -375,7 +412,7 @@ void ggml_metal_get_tensor(
|
|
375
412
|
|
376
413
|
void ggml_metal_graph_find_concurrency(
|
377
414
|
struct ggml_metal_context * ctx,
|
378
|
-
struct ggml_cgraph * gf) {
|
415
|
+
struct ggml_cgraph * gf, bool check_mem) {
|
379
416
|
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
|
380
417
|
int nodes_unused[GGML_MAX_CONCUR];
|
381
418
|
|
@@ -422,7 +459,7 @@ void ggml_metal_graph_find_concurrency(
|
|
422
459
|
}
|
423
460
|
}
|
424
461
|
}
|
425
|
-
if (exe_flag) {
|
462
|
+
if (exe_flag && check_mem) {
|
426
463
|
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
|
427
464
|
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
|
428
465
|
int64_t data_start = (int64_t) gf->nodes[i]->data;
|
@@ -506,19 +543,15 @@ void ggml_metal_graph_compute(
|
|
506
543
|
|
507
544
|
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
508
545
|
|
509
|
-
id<MTLComputeCommandEncoder> encoder =
|
546
|
+
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
510
547
|
|
511
|
-
const int node_start =
|
512
|
-
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
|
548
|
+
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
549
|
+
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
513
550
|
|
514
551
|
for (int ind = node_start; ind < node_end; ++ind) {
|
515
552
|
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
516
553
|
|
517
554
|
if (i == -1) {
|
518
|
-
if (encoder == nil) {
|
519
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
520
|
-
continue;
|
521
|
-
}
|
522
555
|
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
523
556
|
continue;
|
524
557
|
}
|
@@ -592,10 +625,6 @@ void ggml_metal_graph_compute(
|
|
592
625
|
} break;
|
593
626
|
case GGML_OP_ADD:
|
594
627
|
{
|
595
|
-
if (encoder == nil) {
|
596
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
597
|
-
}
|
598
|
-
|
599
628
|
if (ggml_nelements(src1) == ne10) {
|
600
629
|
// src1 is a row
|
601
630
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
@@ -613,10 +642,6 @@ void ggml_metal_graph_compute(
|
|
613
642
|
} break;
|
614
643
|
case GGML_OP_MUL:
|
615
644
|
{
|
616
|
-
if (encoder == nil) {
|
617
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
618
|
-
}
|
619
|
-
|
620
645
|
if (ggml_nelements(src1) == ne10) {
|
621
646
|
// src1 is a row
|
622
647
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
@@ -634,10 +659,6 @@ void ggml_metal_graph_compute(
|
|
634
659
|
} break;
|
635
660
|
case GGML_OP_SCALE:
|
636
661
|
{
|
637
|
-
if (encoder == nil) {
|
638
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
639
|
-
}
|
640
|
-
|
641
662
|
const float scale = *(const float *) src1->data;
|
642
663
|
|
643
664
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
@@ -653,10 +674,6 @@ void ggml_metal_graph_compute(
|
|
653
674
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
654
675
|
case GGML_UNARY_OP_SILU:
|
655
676
|
{
|
656
|
-
if (encoder == nil) {
|
657
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
658
|
-
}
|
659
|
-
|
660
677
|
[encoder setComputePipelineState:ctx->pipeline_silu];
|
661
678
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
662
679
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -667,10 +684,6 @@ void ggml_metal_graph_compute(
|
|
667
684
|
} break;
|
668
685
|
case GGML_UNARY_OP_RELU:
|
669
686
|
{
|
670
|
-
if (encoder == nil) {
|
671
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
672
|
-
}
|
673
|
-
|
674
687
|
[encoder setComputePipelineState:ctx->pipeline_relu];
|
675
688
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
676
689
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -681,10 +694,6 @@ void ggml_metal_graph_compute(
|
|
681
694
|
} break;
|
682
695
|
case GGML_UNARY_OP_GELU:
|
683
696
|
{
|
684
|
-
if (encoder == nil) {
|
685
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
686
|
-
}
|
687
|
-
|
688
697
|
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
689
698
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
690
699
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
@@ -701,10 +710,6 @@ void ggml_metal_graph_compute(
|
|
701
710
|
} break;
|
702
711
|
case GGML_OP_SOFT_MAX:
|
703
712
|
{
|
704
|
-
if (encoder == nil) {
|
705
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
706
|
-
}
|
707
|
-
|
708
713
|
const int nth = 32;
|
709
714
|
|
710
715
|
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
@@ -719,10 +724,6 @@ void ggml_metal_graph_compute(
|
|
719
724
|
} break;
|
720
725
|
case GGML_OP_DIAG_MASK_INF:
|
721
726
|
{
|
722
|
-
if (encoder == nil) {
|
723
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
724
|
-
}
|
725
|
-
|
726
727
|
const int n_past = ((int32_t *)(dst->op_params))[0];
|
727
728
|
|
728
729
|
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
@@ -740,53 +741,43 @@ void ggml_metal_graph_compute(
|
|
740
741
|
|
741
742
|
GGML_ASSERT(ne00 == ne10);
|
742
743
|
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
744
|
+
uint gqa = ne12/ne02;
|
743
745
|
GGML_ASSERT(ne03 == ne13);
|
744
746
|
|
747
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
748
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
745
749
|
if (ggml_is_contiguous(src0) &&
|
746
750
|
ggml_is_contiguous(src1) &&
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
|
763
|
-
|
764
|
-
MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
|
765
|
-
matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
|
766
|
-
|
767
|
-
MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
|
768
|
-
initWithDevice:ctx->device transposeLeft:false transposeRight:true
|
769
|
-
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
|
770
|
-
|
771
|
-
// we need to do ne12 multiplications
|
772
|
-
// TODO: is there a way to do this in parallel - currently very slow ..
|
773
|
-
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
|
774
|
-
for (int64_t i02 = 0; i02 < ne12; ++i02) {
|
775
|
-
size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
|
776
|
-
size_t offs_src1_cur = offs_src1 + i02*nb12;
|
777
|
-
size_t offs_dst_cur = offs_dst + i02*nb2;
|
778
|
-
|
779
|
-
MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
|
780
|
-
MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
|
781
|
-
MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
|
782
|
-
|
783
|
-
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
|
751
|
+
src1t == GGML_TYPE_F32 &&
|
752
|
+
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
753
|
+
ne00%32 == 0 &&
|
754
|
+
ne11 > 1) {
|
755
|
+
switch (src0->type) {
|
756
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
757
|
+
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
758
|
+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
759
|
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
|
760
|
+
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
|
761
|
+
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
|
762
|
+
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
|
763
|
+
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
|
764
|
+
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
765
|
+
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
784
766
|
}
|
767
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
768
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
769
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
770
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
771
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
772
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
773
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
774
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
775
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
776
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
777
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
778
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
779
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
785
780
|
} else {
|
786
|
-
if (encoder == nil) {
|
787
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
788
|
-
}
|
789
|
-
|
790
781
|
int nth0 = 32;
|
791
782
|
int nth1 = 1;
|
792
783
|
|
@@ -816,6 +807,15 @@ void ggml_metal_graph_compute(
|
|
816
807
|
nth1 = 8;
|
817
808
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
|
818
809
|
} break;
|
810
|
+
case GGML_TYPE_Q8_0:
|
811
|
+
{
|
812
|
+
GGML_ASSERT(ne02 == 1);
|
813
|
+
GGML_ASSERT(ne12 == 1);
|
814
|
+
|
815
|
+
nth0 = 8;
|
816
|
+
nth1 = 8;
|
817
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
|
818
|
+
} break;
|
819
819
|
case GGML_TYPE_Q2_K:
|
820
820
|
{
|
821
821
|
GGML_ASSERT(ne02 == 1);
|
@@ -885,23 +885,24 @@ void ggml_metal_graph_compute(
|
|
885
885
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
886
886
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
887
887
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
888
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
888
889
|
|
889
|
-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
890
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
890
891
|
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
891
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)
|
892
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
892
893
|
}
|
893
894
|
else if (src0t == GGML_TYPE_Q3_K) {
|
894
895
|
#ifdef GGML_QKK_64
|
895
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11,
|
896
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
896
897
|
#else
|
897
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11,
|
898
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
898
899
|
#endif
|
899
900
|
}
|
900
901
|
else if (src0t == GGML_TYPE_Q5_K) {
|
901
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)
|
902
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
902
903
|
}
|
903
904
|
else if (src0t == GGML_TYPE_Q6_K) {
|
904
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11,
|
905
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
905
906
|
} else {
|
906
907
|
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
907
908
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -910,14 +911,11 @@ void ggml_metal_graph_compute(
|
|
910
911
|
} break;
|
911
912
|
case GGML_OP_GET_ROWS:
|
912
913
|
{
|
913
|
-
if (encoder == nil) {
|
914
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
915
|
-
}
|
916
|
-
|
917
914
|
switch (src0->type) {
|
918
|
-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];
|
915
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
919
916
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
920
917
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
918
|
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
921
919
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
922
920
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
923
921
|
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
@@ -939,10 +937,6 @@ void ggml_metal_graph_compute(
|
|
939
937
|
} break;
|
940
938
|
case GGML_OP_RMS_NORM:
|
941
939
|
{
|
942
|
-
if (encoder == nil) {
|
943
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
944
|
-
}
|
945
|
-
|
946
940
|
float eps;
|
947
941
|
memcpy(&eps, dst->op_params, sizeof(float));
|
948
942
|
|
@@ -962,20 +956,17 @@ void ggml_metal_graph_compute(
|
|
962
956
|
} break;
|
963
957
|
case GGML_OP_NORM:
|
964
958
|
{
|
965
|
-
|
966
|
-
|
967
|
-
}
|
968
|
-
|
969
|
-
const float eps = 1e-5f;
|
959
|
+
float eps;
|
960
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
970
961
|
|
971
962
|
const int nth = 256;
|
972
963
|
|
973
964
|
[encoder setComputePipelineState:ctx->pipeline_norm];
|
974
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
975
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
976
|
-
[encoder setBytes:&ne00
|
977
|
-
[encoder setBytes:&nb01
|
978
|
-
[encoder setBytes:&eps
|
965
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
966
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
967
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
968
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
969
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
979
970
|
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
980
971
|
|
981
972
|
const int64_t nrows = ggml_nrows(src0);
|
@@ -984,10 +975,6 @@ void ggml_metal_graph_compute(
|
|
984
975
|
} break;
|
985
976
|
case GGML_OP_ALIBI:
|
986
977
|
{
|
987
|
-
if (encoder == nil) {
|
988
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
989
|
-
}
|
990
|
-
|
991
978
|
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
992
979
|
|
993
980
|
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
@@ -1022,15 +1009,13 @@ void ggml_metal_graph_compute(
|
|
1022
1009
|
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
1023
1010
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1024
1011
|
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
1012
|
+
|
1025
1013
|
const int nth = 32;
|
1014
|
+
|
1026
1015
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1027
1016
|
} break;
|
1028
1017
|
case GGML_OP_ROPE:
|
1029
1018
|
{
|
1030
|
-
if (encoder == nil) {
|
1031
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
1032
|
-
}
|
1033
|
-
|
1034
1019
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
1035
1020
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
1036
1021
|
const int mode = ((int32_t *) dst->op_params)[2];
|
@@ -1041,8 +1026,8 @@ void ggml_metal_graph_compute(
|
|
1041
1026
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
1042
1027
|
|
1043
1028
|
[encoder setComputePipelineState:ctx->pipeline_rope];
|
1044
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
1045
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
1029
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1030
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1046
1031
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1047
1032
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1048
1033
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
@@ -1071,10 +1056,6 @@ void ggml_metal_graph_compute(
|
|
1071
1056
|
case GGML_OP_CPY:
|
1072
1057
|
case GGML_OP_CONT:
|
1073
1058
|
{
|
1074
|
-
if (encoder == nil) {
|
1075
|
-
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
1076
|
-
}
|
1077
|
-
|
1078
1059
|
const int nth = 32;
|
1079
1060
|
|
1080
1061
|
switch (src0t) {
|
@@ -1097,24 +1078,24 @@ void ggml_metal_graph_compute(
|
|
1097
1078
|
default: GGML_ASSERT(false && "not implemented");
|
1098
1079
|
}
|
1099
1080
|
|
1100
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
1101
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
1102
|
-
[encoder setBytes:&ne00
|
1103
|
-
[encoder setBytes:&ne01
|
1104
|
-
[encoder setBytes:&ne02
|
1105
|
-
[encoder setBytes:&ne03
|
1106
|
-
[encoder setBytes:&nb00
|
1107
|
-
[encoder setBytes:&nb01
|
1108
|
-
[encoder setBytes:&nb02
|
1109
|
-
[encoder setBytes:&nb03
|
1110
|
-
[encoder setBytes:&ne0
|
1111
|
-
[encoder setBytes:&ne1
|
1112
|
-
[encoder setBytes:&ne2
|
1113
|
-
[encoder setBytes:&ne3
|
1114
|
-
[encoder setBytes:&nb0
|
1115
|
-
[encoder setBytes:&nb1
|
1116
|
-
[encoder setBytes:&nb2
|
1117
|
-
[encoder setBytes:&nb3
|
1081
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1082
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1083
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1084
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1085
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1086
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
1087
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
1088
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
1089
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
1090
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
1091
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
1092
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
1093
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
1094
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
1095
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
1096
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
1097
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
1098
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1118
1099
|
|
1119
1100
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1120
1101
|
} break;
|