llama_cpp 0.12.7 → 0.14.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/ext/llama_cpp/llama_cpp.cpp +131 -288
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +29 -29
- data/vendor/tmp/llama.cpp/Makefile +10 -6
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +6 -3
- data/vendor/tmp/llama.cpp/ggml-backend.c +32 -23
- data/vendor/tmp/llama.cpp/ggml-backend.h +17 -16
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +949 -168
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +159 -22
- data/vendor/tmp/llama.cpp/ggml-metal.metal +1195 -139
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +27 -27
- data/vendor/tmp/llama.cpp/ggml-quants.c +1971 -271
- data/vendor/tmp/llama.cpp/ggml-quants.h +52 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +3586 -1201
- data/vendor/tmp/llama.cpp/ggml-sycl.h +5 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +39336 -43461
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +1391 -825
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +1 -0
- data/vendor/tmp/llama.cpp/ggml.c +545 -210
- data/vendor/tmp/llama.cpp/ggml.h +65 -23
- data/vendor/tmp/llama.cpp/llama.cpp +1458 -763
- data/vendor/tmp/llama.cpp/llama.h +81 -75
- data/vendor/tmp/llama.cpp/unicode.h +310 -1
- metadata +2 -2
@@ -1927,10 +1927,10 @@ static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(g
|
|
1927
1927
|
return ggml_backend_kompute_buffer_type(ctx->device);
|
1928
1928
|
}
|
1929
1929
|
|
1930
|
-
static
|
1930
|
+
static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
1931
1931
|
auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
|
1932
1932
|
ggml_vk_graph_compute(ctx, cgraph);
|
1933
|
-
return
|
1933
|
+
return GGML_STATUS_SUCCESS;
|
1934
1934
|
}
|
1935
1935
|
|
1936
1936
|
static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
@@ -1953,11 +1953,17 @@ static struct ggml_backend_i kompute_backend_i = {
|
|
1953
1953
|
/* .supports_op = */ ggml_backend_kompute_supports_op,
|
1954
1954
|
};
|
1955
1955
|
|
1956
|
+
static ggml_guid_t ggml_backend_kompute_guid() {
|
1957
|
+
static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
|
1958
|
+
return &guid;
|
1959
|
+
}
|
1960
|
+
|
1956
1961
|
ggml_backend_t ggml_backend_kompute_init(int device) {
|
1957
1962
|
GGML_ASSERT(s_kompute_context == nullptr);
|
1958
1963
|
s_kompute_context = new ggml_kompute_context(device);
|
1959
1964
|
|
1960
1965
|
ggml_backend_t kompute_backend = new ggml_backend {
|
1966
|
+
/* .guid = */ ggml_backend_kompute_guid(),
|
1961
1967
|
/* .interface = */ kompute_backend_i,
|
1962
1968
|
/* .context = */ s_kompute_context,
|
1963
1969
|
};
|
@@ -1966,7 +1972,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
|
|
1966
1972
|
}
|
1967
1973
|
|
1968
1974
|
bool ggml_backend_is_kompute(ggml_backend_t backend) {
|
1969
|
-
return backend && backend->
|
1975
|
+
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
|
1970
1976
|
}
|
1971
1977
|
|
1972
1978
|
static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
|
@@ -61,8 +61,11 @@ enum ggml_metal_kernel_type {
|
|
61
61
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
62
62
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
63
63
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
64
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
|
65
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
|
64
66
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
65
67
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
68
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
66
69
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
67
70
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
68
71
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
@@ -85,8 +88,11 @@ enum ggml_metal_kernel_type {
|
|
85
88
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
86
89
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
87
90
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
91
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
|
92
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
|
88
93
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
89
94
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
95
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
90
96
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
91
97
|
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
92
98
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
@@ -105,8 +111,11 @@ enum ggml_metal_kernel_type {
|
|
105
111
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
106
112
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
107
113
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
114
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
|
115
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
|
108
116
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
109
117
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
118
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
110
119
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
111
120
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
112
121
|
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
@@ -122,8 +131,11 @@ enum ggml_metal_kernel_type {
|
|
122
131
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
123
132
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
124
133
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
134
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
|
135
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
|
125
136
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
126
137
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
138
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
127
139
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
128
140
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
129
141
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
@@ -139,8 +151,11 @@ enum ggml_metal_kernel_type {
|
|
139
151
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
140
152
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
141
153
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
154
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
|
155
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
|
142
156
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
143
157
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
158
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
144
159
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
145
160
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
146
161
|
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
@@ -148,6 +163,8 @@ enum ggml_metal_kernel_type {
|
|
148
163
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
149
164
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
150
165
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
166
|
+
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
167
|
+
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
|
151
168
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
152
169
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
153
170
|
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
@@ -452,8 +469,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
452
469
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
453
470
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
454
471
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
472
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
473
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
455
474
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
456
475
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
476
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
457
477
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
458
478
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
459
479
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
@@ -476,8 +496,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
476
496
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
477
497
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
478
498
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
499
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
500
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
479
501
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
480
502
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
503
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
481
504
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
482
505
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
483
506
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
@@ -496,8 +519,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
496
519
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
497
520
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
498
521
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
522
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
523
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
499
524
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
500
525
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
526
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
501
527
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
502
528
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
503
529
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
@@ -513,8 +539,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
513
539
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
514
540
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
515
541
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
542
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
543
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
516
544
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
517
545
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
546
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
518
547
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
519
548
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
520
549
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
@@ -530,8 +559,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
530
559
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
531
560
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
532
561
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
562
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
563
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
533
564
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
534
565
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
566
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
535
567
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
536
568
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
537
569
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
@@ -539,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
539
571
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
540
572
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
541
573
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
574
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
575
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
542
576
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
543
577
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
544
578
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
@@ -667,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
667
701
|
return false;
|
668
702
|
case GGML_OP_UPSCALE:
|
669
703
|
case GGML_OP_PAD:
|
704
|
+
case GGML_OP_ARANGE:
|
705
|
+
case GGML_OP_TIMESTEP_EMBEDDING:
|
670
706
|
case GGML_OP_ARGSORT:
|
671
707
|
case GGML_OP_LEAKY_RELU:
|
672
708
|
return true;
|
@@ -712,7 +748,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
712
748
|
}
|
713
749
|
}
|
714
750
|
|
715
|
-
static
|
751
|
+
static enum ggml_status ggml_metal_graph_compute(
|
716
752
|
struct ggml_metal_context * ctx,
|
717
753
|
struct ggml_cgraph * gf) {
|
718
754
|
|
@@ -1061,7 +1097,8 @@ static bool ggml_metal_graph_compute(
|
|
1061
1097
|
{
|
1062
1098
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
1063
1099
|
|
1064
|
-
|
1100
|
+
float scale;
|
1101
|
+
memcpy(&scale, dst->op_params, sizeof(scale));
|
1065
1102
|
|
1066
1103
|
int64_t n = ggml_nelements(dst);
|
1067
1104
|
|
@@ -1220,11 +1257,15 @@ static bool ggml_metal_graph_compute(
|
|
1220
1257
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
1221
1258
|
}
|
1222
1259
|
|
1223
|
-
|
1224
|
-
|
1260
|
+
float scale;
|
1261
|
+
float max_bias;
|
1262
|
+
|
1263
|
+
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
1264
|
+
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
1225
1265
|
|
1226
1266
|
const int64_t nrows_x = ggml_nrows(src0);
|
1227
1267
|
const int64_t nrows_y = src0->ne[1];
|
1268
|
+
|
1228
1269
|
const uint32_t n_head_kv = nrows_x/nrows_y;
|
1229
1270
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
1230
1271
|
|
@@ -1347,8 +1388,11 @@ static bool ggml_metal_graph_compute(
|
|
1347
1388
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
1348
1389
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
1349
1390
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
1391
|
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
1392
|
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
1350
1393
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
1351
1394
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
1395
|
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
1352
1396
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
1353
1397
|
}
|
1354
1398
|
|
@@ -1483,6 +1527,18 @@ static bool ggml_metal_graph_compute(
|
|
1483
1527
|
nth1 = 16;
|
1484
1528
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
1485
1529
|
} break;
|
1530
|
+
case GGML_TYPE_IQ3_S:
|
1531
|
+
{
|
1532
|
+
nth0 = 4;
|
1533
|
+
nth1 = 16;
|
1534
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
1535
|
+
} break;
|
1536
|
+
case GGML_TYPE_IQ2_S:
|
1537
|
+
{
|
1538
|
+
nth0 = 4;
|
1539
|
+
nth1 = 16;
|
1540
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
1541
|
+
} break;
|
1486
1542
|
case GGML_TYPE_IQ1_S:
|
1487
1543
|
{
|
1488
1544
|
nth0 = 4;
|
@@ -1495,6 +1551,12 @@ static bool ggml_metal_graph_compute(
|
|
1495
1551
|
nth1 = 16;
|
1496
1552
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
1497
1553
|
} break;
|
1554
|
+
case GGML_TYPE_IQ4_XS:
|
1555
|
+
{
|
1556
|
+
nth0 = 4;
|
1557
|
+
nth1 = 16;
|
1558
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
1559
|
+
} break;
|
1498
1560
|
default:
|
1499
1561
|
{
|
1500
1562
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
@@ -1527,9 +1589,9 @@ static bool ggml_metal_graph_compute(
|
|
1527
1589
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
1528
1590
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
1529
1591
|
|
1530
|
-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1
|
1531
|
-
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1
|
1532
|
-
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S
|
1592
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
1593
|
+
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
1594
|
+
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
|
1533
1595
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1534
1596
|
}
|
1535
1597
|
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
@@ -1537,12 +1599,12 @@ static bool ggml_metal_graph_compute(
|
|
1537
1599
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1538
1600
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1539
1601
|
}
|
1540
|
-
else if (src0t == GGML_TYPE_IQ3_XXS) {
|
1541
|
-
const int mem_size = 256*4+128;
|
1602
|
+
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
1603
|
+
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
1542
1604
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1543
1605
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1544
1606
|
}
|
1545
|
-
else if (src0t == GGML_TYPE_IQ4_NL) {
|
1607
|
+
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
1546
1608
|
const int mem_size = 32*sizeof(float);
|
1547
1609
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1548
1610
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -1640,8 +1702,11 @@ static bool ggml_metal_graph_compute(
|
|
1640
1702
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
1641
1703
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
1642
1704
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
1705
|
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
1706
|
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
1643
1707
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
1644
1708
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
1709
|
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
1645
1710
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
1646
1711
|
}
|
1647
1712
|
|
@@ -1779,6 +1844,18 @@ static bool ggml_metal_graph_compute(
|
|
1779
1844
|
nth1 = 16;
|
1780
1845
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
1781
1846
|
} break;
|
1847
|
+
case GGML_TYPE_IQ3_S:
|
1848
|
+
{
|
1849
|
+
nth0 = 4;
|
1850
|
+
nth1 = 16;
|
1851
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
|
1852
|
+
} break;
|
1853
|
+
case GGML_TYPE_IQ2_S:
|
1854
|
+
{
|
1855
|
+
nth0 = 4;
|
1856
|
+
nth1 = 16;
|
1857
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
|
1858
|
+
} break;
|
1782
1859
|
case GGML_TYPE_IQ1_S:
|
1783
1860
|
{
|
1784
1861
|
nth0 = 4;
|
@@ -1791,6 +1868,12 @@ static bool ggml_metal_graph_compute(
|
|
1791
1868
|
nth1 = 16;
|
1792
1869
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
1793
1870
|
} break;
|
1871
|
+
case GGML_TYPE_IQ4_XS:
|
1872
|
+
{
|
1873
|
+
nth0 = 4;
|
1874
|
+
nth1 = 16;
|
1875
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
1876
|
+
} break;
|
1794
1877
|
default:
|
1795
1878
|
{
|
1796
1879
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
@@ -1839,9 +1922,9 @@ static bool ggml_metal_graph_compute(
|
|
1839
1922
|
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
1840
1923
|
}
|
1841
1924
|
|
1842
|
-
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1
|
1843
|
-
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1
|
1844
|
-
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S
|
1925
|
+
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
1926
|
+
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
1927
|
+
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
|
1845
1928
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1846
1929
|
}
|
1847
1930
|
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
@@ -1849,12 +1932,12 @@ static bool ggml_metal_graph_compute(
|
|
1849
1932
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1850
1933
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1851
1934
|
}
|
1852
|
-
else if (src2t == GGML_TYPE_IQ3_XXS) {
|
1853
|
-
const int mem_size = 256*4+128;
|
1935
|
+
else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
|
1936
|
+
const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
1854
1937
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1855
1938
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1856
1939
|
}
|
1857
|
-
else if (src2t == GGML_TYPE_IQ4_NL) {
|
1940
|
+
else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
|
1858
1941
|
const int mem_size = 32*sizeof(float);
|
1859
1942
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1860
1943
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -1900,8 +1983,11 @@ static bool ggml_metal_graph_compute(
|
|
1900
1983
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
1901
1984
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
1902
1985
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
1986
|
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
|
1987
|
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
|
1903
1988
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
1904
1989
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
|
1990
|
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
|
1905
1991
|
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
1906
1992
|
default: GGML_ASSERT(false && "not implemented");
|
1907
1993
|
}
|
@@ -2011,6 +2097,7 @@ static bool ggml_metal_graph_compute(
|
|
2011
2097
|
|
2012
2098
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
2013
2099
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
2100
|
+
|
2014
2101
|
float max_bias;
|
2015
2102
|
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
2016
2103
|
|
@@ -2225,6 +2312,50 @@ static bool ggml_metal_graph_compute(
|
|
2225
2312
|
|
2226
2313
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2227
2314
|
} break;
|
2315
|
+
case GGML_OP_ARANGE:
|
2316
|
+
{
|
2317
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
2318
|
+
|
2319
|
+
float start;
|
2320
|
+
float step;
|
2321
|
+
|
2322
|
+
memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
2323
|
+
memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
|
2324
|
+
|
2325
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
2326
|
+
|
2327
|
+
[encoder setComputePipelineState:pipeline];
|
2328
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
2329
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
|
2330
|
+
[encoder setBytes:&start length:sizeof(start) atIndex:2];
|
2331
|
+
[encoder setBytes:&step length:sizeof(step) atIndex:3];
|
2332
|
+
|
2333
|
+
const int nth = MIN(1024, ne0);
|
2334
|
+
|
2335
|
+
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2336
|
+
} break;
|
2337
|
+
case GGML_OP_TIMESTEP_EMBEDDING:
|
2338
|
+
{
|
2339
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2340
|
+
|
2341
|
+
const int dim = dst->op_params[0];
|
2342
|
+
const int max_period = dst->op_params[1];
|
2343
|
+
|
2344
|
+
const int half = dim / 2;
|
2345
|
+
|
2346
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
|
2347
|
+
|
2348
|
+
[encoder setComputePipelineState:pipeline];
|
2349
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2350
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2351
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
|
2352
|
+
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
|
2353
|
+
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
|
2354
|
+
|
2355
|
+
const int nth = MIN(1024, half);
|
2356
|
+
|
2357
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2358
|
+
} break;
|
2228
2359
|
case GGML_OP_ARGSORT:
|
2229
2360
|
{
|
2230
2361
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
@@ -2237,8 +2368,8 @@ static bool ggml_metal_graph_compute(
|
|
2237
2368
|
id<MTLComputePipelineState> pipeline = nil;
|
2238
2369
|
|
2239
2370
|
switch (order) {
|
2240
|
-
case
|
2241
|
-
case
|
2371
|
+
case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
|
2372
|
+
case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
|
2242
2373
|
default: GGML_ASSERT(false);
|
2243
2374
|
};
|
2244
2375
|
|
@@ -2353,7 +2484,7 @@ static bool ggml_metal_graph_compute(
|
|
2353
2484
|
MTLCommandBufferStatus status = [command_buffer status];
|
2354
2485
|
if (status != MTLCommandBufferStatusCompleted) {
|
2355
2486
|
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
2356
|
-
return
|
2487
|
+
return GGML_STATUS_FAILED;
|
2357
2488
|
}
|
2358
2489
|
}
|
2359
2490
|
|
@@ -2362,7 +2493,7 @@ static bool ggml_metal_graph_compute(
|
|
2362
2493
|
}
|
2363
2494
|
|
2364
2495
|
}
|
2365
|
-
return
|
2496
|
+
return GGML_STATUS_SUCCESS;
|
2366
2497
|
}
|
2367
2498
|
|
2368
2499
|
////////////////////////////////////////////////////////////////////////////////
|
@@ -2664,7 +2795,7 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffe
|
|
2664
2795
|
UNUSED(backend);
|
2665
2796
|
}
|
2666
2797
|
|
2667
|
-
GGML_CALL static
|
2798
|
+
GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
2668
2799
|
struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
|
2669
2800
|
|
2670
2801
|
return ggml_metal_graph_compute(metal_ctx, cgraph);
|
@@ -2696,6 +2827,11 @@ void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void *
|
|
2696
2827
|
ggml_metal_log_user_data = user_data;
|
2697
2828
|
}
|
2698
2829
|
|
2830
|
+
static ggml_guid_t ggml_backend_metal_guid(void) {
|
2831
|
+
static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
|
2832
|
+
return &guid;
|
2833
|
+
}
|
2834
|
+
|
2699
2835
|
ggml_backend_t ggml_backend_metal_init(void) {
|
2700
2836
|
struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
2701
2837
|
|
@@ -2706,6 +2842,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
2706
2842
|
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
2707
2843
|
|
2708
2844
|
*metal_backend = (struct ggml_backend) {
|
2845
|
+
/* .guid = */ ggml_backend_metal_guid(),
|
2709
2846
|
/* .interface = */ ggml_backend_metal_i,
|
2710
2847
|
/* .context = */ ctx,
|
2711
2848
|
};
|
@@ -2714,7 +2851,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
2714
2851
|
}
|
2715
2852
|
|
2716
2853
|
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
2717
|
-
return backend && backend->
|
2854
|
+
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
|
2718
2855
|
}
|
2719
2856
|
|
2720
2857
|
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|