llama_cpp 0.14.2 → 0.14.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +64 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +91 -21
- data/vendor/tmp/llama.cpp/ggml-alloc.c +14 -5
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +5 -0
- data/vendor/tmp/llama.cpp/ggml-backend.c +155 -125
- data/vendor/tmp/llama.cpp/ggml-backend.h +4 -4
- data/vendor/tmp/llama.cpp/ggml-common.h +25 -2
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +1779 -10762
- data/vendor/tmp/llama.cpp/ggml-cuda.h +6 -15
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +167 -124
- data/vendor/tmp/llama.cpp/ggml-metal.metal +603 -303
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +663 -56
- data/vendor/tmp/llama.cpp/ggml-quants.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +341 -469
- data/vendor/tmp/llama.cpp/ggml-sycl.h +19 -4
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +37199 -14939
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +335 -307
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -11
- data/vendor/tmp/llama.cpp/ggml.c +229 -107
- data/vendor/tmp/llama.cpp/ggml.h +11 -5
- data/vendor/tmp/llama.cpp/llama.cpp +2136 -464
- data/vendor/tmp/llama.cpp/llama.h +86 -23
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1651 -0
- data/vendor/tmp/llama.cpp/unicode-data.h +16 -0
- data/vendor/tmp/llama.cpp/unicode.cpp +8 -1403
- data/vendor/tmp/llama.cpp/unicode.h +2 -0
- metadata +5 -3
@@ -17,29 +17,17 @@ extern "C" {
|
|
17
17
|
|
18
18
|
#define GGML_CUDA_MAX_DEVICES 16
|
19
19
|
|
20
|
-
// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
|
21
|
-
GGML_API GGML_CALL void ggml_init_cublas(void);
|
22
|
-
|
23
|
-
// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
|
24
|
-
GGML_API GGML_CALL bool ggml_cublas_loaded(void);
|
25
|
-
|
26
|
-
GGML_API GGML_CALL void * ggml_cuda_host_malloc(size_t size);
|
27
|
-
GGML_API GGML_CALL void ggml_cuda_host_free(void * ptr);
|
28
|
-
|
29
|
-
GGML_API GGML_CALL bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
30
|
-
GGML_API GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
31
|
-
|
32
|
-
GGML_API GGML_CALL int ggml_cuda_get_device_count(void);
|
33
|
-
GGML_API GGML_CALL void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
|
34
|
-
|
35
20
|
// backend API
|
36
21
|
GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
|
37
22
|
|
38
23
|
GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
|
39
24
|
|
25
|
+
// device buffer
|
40
26
|
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
|
27
|
+
|
41
28
|
// split tensor buffer that splits matrices by rows across multiple devices
|
42
29
|
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
|
30
|
+
|
43
31
|
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
44
32
|
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
|
45
33
|
|
@@ -47,6 +35,9 @@ GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
|
|
47
35
|
GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
|
48
36
|
GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
|
49
37
|
|
38
|
+
GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
|
39
|
+
GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
|
40
|
+
|
50
41
|
#ifdef __cplusplus
|
51
42
|
}
|
52
43
|
#endif
|
@@ -1430,6 +1430,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
1430
1430
|
struct ggml_tensor * dst = gf->nodes[i];
|
1431
1431
|
GGML_ASSERT(dst->data != nullptr);
|
1432
1432
|
|
1433
|
+
if (ggml_is_empty(dst)) {
|
1434
|
+
continue;
|
1435
|
+
}
|
1436
|
+
|
1433
1437
|
switch (dst->op) {
|
1434
1438
|
case GGML_OP_NONE:
|
1435
1439
|
case GGML_OP_RESHAPE:
|
@@ -1951,6 +1955,7 @@ static struct ggml_backend_i kompute_backend_i = {
|
|
1951
1955
|
/* .graph_plan_compute = */ NULL,
|
1952
1956
|
/* .graph_compute = */ ggml_backend_kompute_graph_compute,
|
1953
1957
|
/* .supports_op = */ ggml_backend_kompute_supports_op,
|
1958
|
+
/* .offload_op = */ NULL,
|
1954
1959
|
/* .event_new = */ NULL,
|
1955
1960
|
/* .event_free = */ NULL,
|
1956
1961
|
/* .event_record = */ NULL,
|
@@ -64,6 +64,7 @@ enum ggml_metal_kernel_type {
|
|
64
64
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
|
65
65
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
|
66
66
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
67
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
|
67
68
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
68
69
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
69
70
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
@@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
|
|
91
92
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
|
92
93
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
|
93
94
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
95
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
|
94
96
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
95
97
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
96
98
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
@@ -114,6 +116,7 @@ enum ggml_metal_kernel_type {
|
|
114
116
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
|
115
117
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
|
116
118
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
119
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
|
117
120
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
118
121
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
119
122
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
@@ -134,6 +137,7 @@ enum ggml_metal_kernel_type {
|
|
134
137
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
|
135
138
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
|
136
139
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
140
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
137
141
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
138
142
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
139
143
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
@@ -154,6 +158,7 @@ enum ggml_metal_kernel_type {
|
|
154
158
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
|
155
159
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
|
156
160
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
161
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
|
157
162
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
158
163
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
159
164
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
@@ -173,8 +178,9 @@ enum ggml_metal_kernel_type {
|
|
173
178
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
174
179
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
175
180
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
176
|
-
|
177
|
-
|
181
|
+
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
182
|
+
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
183
|
+
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
178
184
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
179
185
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
180
186
|
GGML_METAL_KERNEL_TYPE_CONCAT,
|
@@ -489,6 +495,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
489
495
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
490
496
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
491
497
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
498
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
|
492
499
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
493
500
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
494
501
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
@@ -516,6 +523,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
516
523
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
517
524
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
518
525
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
526
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
|
519
527
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
520
528
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
521
529
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
@@ -539,6 +547,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
539
547
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
540
548
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
541
549
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
550
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
|
542
551
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
543
552
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
544
553
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
@@ -559,6 +568,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
559
568
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
560
569
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
561
570
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
571
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
|
562
572
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
563
573
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
564
574
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
@@ -579,6 +589,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
579
589
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
580
590
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
581
591
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
592
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
582
593
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
583
594
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
584
595
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
@@ -598,8 +609,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
598
609
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
599
610
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
600
611
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
601
|
-
|
602
|
-
|
612
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
613
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
614
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
603
615
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
604
616
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
605
617
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
@@ -739,6 +751,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
739
751
|
case GGML_TYPE_Q8_0:
|
740
752
|
case GGML_TYPE_Q4_0:
|
741
753
|
case GGML_TYPE_Q4_1:
|
754
|
+
case GGML_TYPE_Q5_0:
|
755
|
+
case GGML_TYPE_Q5_1:
|
756
|
+
case GGML_TYPE_IQ4_NL:
|
742
757
|
return true;
|
743
758
|
default:
|
744
759
|
return false;
|
@@ -832,6 +847,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
832
847
|
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
833
848
|
struct ggml_tensor * dst = gf->nodes[i];
|
834
849
|
|
850
|
+
if (ggml_is_empty(dst)) {
|
851
|
+
continue;
|
852
|
+
}
|
853
|
+
|
835
854
|
switch (dst->op) {
|
836
855
|
case GGML_OP_NONE:
|
837
856
|
case GGML_OP_RESHAPE:
|
@@ -1387,6 +1406,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1387
1406
|
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
|
1388
1407
|
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
1389
1408
|
|
1409
|
+
// some Metal matrix data types require aligned pointers
|
1410
|
+
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
1411
|
+
switch (src0->type) {
|
1412
|
+
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
1413
|
+
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
1414
|
+
default: break;
|
1415
|
+
}
|
1416
|
+
|
1390
1417
|
id<MTLComputePipelineState> pipeline = nil;
|
1391
1418
|
|
1392
1419
|
switch (src0->type) {
|
@@ -1408,6 +1435,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1408
1435
|
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
1409
1436
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
1410
1437
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
1438
|
+
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
|
1411
1439
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
1412
1440
|
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
1413
1441
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
@@ -1562,6 +1590,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1562
1590
|
nth1 = 16;
|
1563
1591
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
1564
1592
|
} break;
|
1593
|
+
case GGML_TYPE_IQ1_M:
|
1594
|
+
{
|
1595
|
+
nth0 = 4;
|
1596
|
+
nth1 = 16;
|
1597
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
1598
|
+
} break;
|
1565
1599
|
case GGML_TYPE_IQ4_NL:
|
1566
1600
|
{
|
1567
1601
|
nth0 = 4;
|
@@ -1606,9 +1640,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1606
1640
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
1607
1641
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
1608
1642
|
|
1609
|
-
if (src0t == GGML_TYPE_Q4_0
|
1610
|
-
src0t ==
|
1611
|
-
src0t ==
|
1643
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
1644
|
+
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
1645
|
+
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
1612
1646
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1613
1647
|
}
|
1614
1648
|
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
@@ -1651,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1651
1685
|
{
|
1652
1686
|
//GGML_ASSERT(ne00 == ne10);
|
1653
1687
|
//GGML_ASSERT(ne03 == ne13);
|
1654
|
-
|
1655
|
-
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
1656
|
-
|
1657
|
-
const int n_as = ((int32_t *) dst->op_params)[1];
|
1658
|
-
|
1659
|
-
// TODO: make this more general
|
1660
|
-
GGML_ASSERT(n_as <= 8);
|
1688
|
+
const int n_as = src0->ne[2];
|
1661
1689
|
|
1662
1690
|
// max size of the src1ids array in the kernel shared buffer
|
1663
1691
|
GGML_ASSERT(ne11 <= 4096);
|
1664
1692
|
|
1665
|
-
|
1666
|
-
const int64_t
|
1667
|
-
const int64_t
|
1668
|
-
const int64_t
|
1693
|
+
// src2 = ids
|
1694
|
+
const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
|
1695
|
+
const int64_t ne21 = src2->ne[1];
|
1696
|
+
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
1697
|
+
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
1698
|
+
|
1699
|
+
const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
|
1700
|
+
const uint64_t nb21 = src2->nb[1];
|
1701
|
+
const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
|
1702
|
+
const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
|
1669
1703
|
|
1670
|
-
const
|
1671
|
-
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
1672
|
-
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
1673
|
-
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
1704
|
+
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
1674
1705
|
|
1675
|
-
|
1706
|
+
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
1676
1707
|
|
1677
|
-
GGML_ASSERT(!ggml_is_transposed(
|
1708
|
+
GGML_ASSERT(!ggml_is_transposed(src0));
|
1678
1709
|
GGML_ASSERT(!ggml_is_transposed(src1));
|
1679
1710
|
|
1680
1711
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1681
1712
|
|
1682
|
-
const uint r2 = ne12/ne22;
|
1683
|
-
const uint r3 = ne13/ne23;
|
1684
|
-
|
1685
1713
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1686
1714
|
// to the matrix-vector kernel
|
1687
1715
|
int ne11_mm_min = n_as;
|
@@ -1689,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1689
1717
|
const int idx = ((int32_t *) dst->op_params)[0];
|
1690
1718
|
|
1691
1719
|
// batch size
|
1692
|
-
GGML_ASSERT(
|
1720
|
+
GGML_ASSERT(ne21 == ne11); // ?
|
1721
|
+
GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
|
1722
|
+
const uint r2 = 1;
|
1723
|
+
const uint r3 = 1;
|
1693
1724
|
|
1694
1725
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1695
1726
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
@@ -1698,12 +1729,20 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1698
1729
|
// indirect matrix multiplication
|
1699
1730
|
// !!!
|
1700
1731
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
1701
|
-
|
1732
|
+
ne00 % 32 == 0 && ne00 >= 64 &&
|
1702
1733
|
ne11 > ne11_mm_min) {
|
1703
1734
|
|
1735
|
+
// some Metal matrix data types require aligned pointers
|
1736
|
+
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
1737
|
+
switch (src0->type) {
|
1738
|
+
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
1739
|
+
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
1740
|
+
default: break;
|
1741
|
+
}
|
1742
|
+
|
1704
1743
|
id<MTLComputePipelineState> pipeline = nil;
|
1705
1744
|
|
1706
|
-
switch (
|
1745
|
+
switch (src0->type) {
|
1707
1746
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
1708
1747
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
1709
1748
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
@@ -1722,6 +1761,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1722
1761
|
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
1723
1762
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
1724
1763
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
1764
|
+
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
|
1725
1765
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
1726
1766
|
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
1727
1767
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
@@ -1731,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1731
1771
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1732
1772
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1733
1773
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1734
|
-
[encoder
|
1735
|
-
[encoder setBytes:&
|
1736
|
-
[encoder setBytes:&
|
1737
|
-
[encoder setBytes:&
|
1738
|
-
[encoder setBytes:&
|
1739
|
-
[encoder setBytes:&
|
1740
|
-
[encoder setBytes:&
|
1741
|
-
[encoder setBytes:&
|
1742
|
-
[encoder setBytes:&
|
1743
|
-
[encoder setBytes:&
|
1744
|
-
[encoder setBytes:&
|
1745
|
-
[encoder setBytes:&
|
1746
|
-
[encoder setBytes:&
|
1747
|
-
[encoder setBytes:&
|
1748
|
-
[encoder setBytes:&
|
1749
|
-
[encoder setBytes:&
|
1750
|
-
|
1751
|
-
for (int j = 0; j < 8; ++j) {
|
1752
|
-
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
1753
|
-
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
1754
|
-
|
1755
|
-
size_t offs_src_cur = 0;
|
1756
|
-
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
|
1757
|
-
|
1758
|
-
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
1759
|
-
}
|
1774
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
1775
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
|
1776
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
|
1777
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
|
1778
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1779
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1780
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
|
1781
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
|
1782
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
|
1783
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
|
1784
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
|
1785
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
|
1786
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
|
1787
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
|
1788
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
1789
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
1790
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:19];
|
1760
1791
|
|
1761
1792
|
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
|
1762
1793
|
|
1763
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (
|
1794
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1764
1795
|
} else {
|
1765
1796
|
int nth0 = 32;
|
1766
1797
|
int nth1 = 1;
|
@@ -1770,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1770
1801
|
id<MTLComputePipelineState> pipeline = nil;
|
1771
1802
|
|
1772
1803
|
// use custom matrix x vector kernel
|
1773
|
-
switch (
|
1804
|
+
switch (src0t) {
|
1774
1805
|
case GGML_TYPE_F32:
|
1775
1806
|
{
|
1776
1807
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
@@ -1879,6 +1910,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1879
1910
|
nth1 = 16;
|
1880
1911
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
|
1881
1912
|
} break;
|
1913
|
+
case GGML_TYPE_IQ1_M:
|
1914
|
+
{
|
1915
|
+
nth0 = 4;
|
1916
|
+
nth1 = 16;
|
1917
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
|
1918
|
+
} break;
|
1882
1919
|
case GGML_TYPE_IQ4_NL:
|
1883
1920
|
{
|
1884
1921
|
nth0 = 4;
|
@@ -1898,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1898
1935
|
}
|
1899
1936
|
};
|
1900
1937
|
|
1901
|
-
if (ggml_is_quantized(
|
1902
|
-
GGML_ASSERT(
|
1938
|
+
if (ggml_is_quantized(src0t)) {
|
1939
|
+
GGML_ASSERT(ne00 >= nth0*nth1);
|
1903
1940
|
}
|
1904
1941
|
|
1905
1942
|
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
@@ -1908,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1908
1945
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1909
1946
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1910
1947
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1911
|
-
[encoder
|
1912
|
-
[encoder setBytes:&
|
1913
|
-
[encoder setBytes:&
|
1914
|
-
[encoder setBytes:&
|
1915
|
-
[encoder setBytes:&
|
1916
|
-
[encoder setBytes:&
|
1917
|
-
[encoder setBytes:&
|
1918
|
-
[encoder setBytes:&
|
1919
|
-
[encoder setBytes:&
|
1920
|
-
[encoder setBytes:&
|
1921
|
-
[encoder setBytes:&
|
1922
|
-
[encoder setBytes:&
|
1923
|
-
[encoder setBytes:&
|
1924
|
-
[encoder setBytes:&
|
1925
|
-
[encoder setBytes:&
|
1926
|
-
[encoder setBytes:&
|
1927
|
-
[encoder setBytes:&
|
1928
|
-
[encoder setBytes:&
|
1929
|
-
[encoder setBytes:&
|
1930
|
-
[encoder setBytes:&
|
1931
|
-
|
1932
|
-
|
1933
|
-
|
1934
|
-
|
1935
|
-
|
1936
|
-
|
1937
|
-
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
|
1938
|
-
|
1939
|
-
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
1940
|
-
}
|
1941
|
-
|
1942
|
-
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
1943
|
-
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
1944
|
-
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
|
1945
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1948
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
1949
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
|
1950
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
|
1951
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
|
1952
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
|
1953
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
|
1954
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
|
1955
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
|
1956
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
1957
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
|
1958
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
1959
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
1960
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
1961
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
1962
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
1963
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
1964
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
|
1965
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
|
1966
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
|
1967
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
|
1968
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:23];
|
1969
|
+
|
1970
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
1971
|
+
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
1972
|
+
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
1973
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1946
1974
|
}
|
1947
|
-
else if (
|
1948
|
-
const int mem_size =
|
1975
|
+
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
1976
|
+
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
1949
1977
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1950
|
-
[encoder dispatchThreadgroups:MTLSizeMake((
|
1978
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1951
1979
|
}
|
1952
|
-
else if (
|
1953
|
-
const int mem_size =
|
1980
|
+
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
1981
|
+
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
1954
1982
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1955
|
-
[encoder dispatchThreadgroups:MTLSizeMake((
|
1983
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1956
1984
|
}
|
1957
|
-
else if (
|
1985
|
+
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
1958
1986
|
const int mem_size = 32*sizeof(float);
|
1959
1987
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1960
|
-
[encoder dispatchThreadgroups:MTLSizeMake((
|
1988
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1961
1989
|
}
|
1962
|
-
else if (
|
1963
|
-
[encoder dispatchThreadgroups:MTLSizeMake((
|
1990
|
+
else if (src0t == GGML_TYPE_Q4_K) {
|
1991
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1964
1992
|
}
|
1965
|
-
else if (
|
1993
|
+
else if (src0t == GGML_TYPE_Q3_K) {
|
1966
1994
|
#ifdef GGML_QKK_64
|
1967
|
-
[encoder dispatchThreadgroups:MTLSizeMake((
|
1995
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1968
1996
|
#else
|
1969
|
-
[encoder dispatchThreadgroups:MTLSizeMake((
|
1997
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1970
1998
|
#endif
|
1971
1999
|
}
|
1972
|
-
else if (
|
1973
|
-
[encoder dispatchThreadgroups:MTLSizeMake((
|
2000
|
+
else if (src0t == GGML_TYPE_Q5_K) {
|
2001
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1974
2002
|
}
|
1975
|
-
else if (
|
1976
|
-
[encoder dispatchThreadgroups:MTLSizeMake((
|
2003
|
+
else if (src0t == GGML_TYPE_Q6_K) {
|
2004
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1977
2005
|
} else {
|
1978
2006
|
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
1979
|
-
[encoder dispatchThreadgroups:MTLSizeMake(
|
2007
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1980
2008
|
}
|
1981
2009
|
}
|
1982
2010
|
} break;
|
@@ -2003,6 +2031,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2003
2031
|
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
|
2004
2032
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
|
2005
2033
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
2034
|
+
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
|
2006
2035
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
|
2007
2036
|
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
|
2008
2037
|
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
@@ -2382,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2382
2411
|
|
2383
2412
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
2384
2413
|
|
2414
|
+
// bitonic sort requires the number of elements to be power of 2
|
2415
|
+
int64_t ne00_padded = 1;
|
2416
|
+
while (ne00_padded < ne00) {
|
2417
|
+
ne00_padded *= 2;
|
2418
|
+
}
|
2419
|
+
|
2420
|
+
// Metal kernels require the buffer size to be multiple of 16 bytes
|
2421
|
+
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
2422
|
+
const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
|
2423
|
+
|
2385
2424
|
id<MTLComputePipelineState> pipeline = nil;
|
2386
2425
|
|
2387
2426
|
switch (order) {
|
@@ -2391,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2391
2430
|
};
|
2392
2431
|
|
2393
2432
|
[encoder setComputePipelineState:pipeline];
|
2394
|
-
[encoder setBuffer:id_src0
|
2395
|
-
[encoder setBuffer:id_dst
|
2396
|
-
[encoder setBytes:&ne00
|
2433
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2434
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2435
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2436
|
+
[encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
|
2437
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2397
2438
|
|
2398
|
-
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(
|
2439
|
+
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
|
2399
2440
|
} break;
|
2400
2441
|
case GGML_OP_LEAKY_RELU:
|
2401
2442
|
{
|
@@ -2431,13 +2472,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2431
2472
|
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
2432
2473
|
|
2433
2474
|
switch (dstt) {
|
2434
|
-
case GGML_TYPE_F16:
|
2435
|
-
case GGML_TYPE_F32:
|
2436
|
-
case GGML_TYPE_Q8_0:
|
2437
|
-
case GGML_TYPE_Q4_0:
|
2438
|
-
case GGML_TYPE_Q4_1:
|
2439
|
-
|
2440
|
-
|
2475
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
2476
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
2477
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
2478
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
2479
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
2480
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
2481
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
2482
|
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
|
2441
2483
|
default: GGML_ASSERT(false && "not implemented");
|
2442
2484
|
};
|
2443
2485
|
} break;
|
@@ -2837,6 +2879,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
|
|
2837
2879
|
/* .graph_plan_compute = */ NULL,
|
2838
2880
|
/* .graph_compute = */ ggml_backend_metal_graph_compute,
|
2839
2881
|
/* .supports_op = */ ggml_backend_metal_supports_op,
|
2882
|
+
/* .offload_op = */ NULL,
|
2840
2883
|
/* .event_new = */ NULL,
|
2841
2884
|
/* .event_free = */ NULL,
|
2842
2885
|
/* .event_record = */ NULL,
|