llama_cpp 0.12.7 → 0.14.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/ext/llama_cpp/llama_cpp.cpp +131 -288
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +29 -29
- data/vendor/tmp/llama.cpp/Makefile +10 -6
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +6 -3
- data/vendor/tmp/llama.cpp/ggml-backend.c +32 -23
- data/vendor/tmp/llama.cpp/ggml-backend.h +17 -16
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +949 -168
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +159 -22
- data/vendor/tmp/llama.cpp/ggml-metal.metal +1195 -139
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +27 -27
- data/vendor/tmp/llama.cpp/ggml-quants.c +1971 -271
- data/vendor/tmp/llama.cpp/ggml-quants.h +52 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +3586 -1201
- data/vendor/tmp/llama.cpp/ggml-sycl.h +5 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +39336 -43461
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +1391 -825
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +1 -0
- data/vendor/tmp/llama.cpp/ggml.c +545 -210
- data/vendor/tmp/llama.cpp/ggml.h +65 -23
- data/vendor/tmp/llama.cpp/llama.cpp +1458 -763
- data/vendor/tmp/llama.cpp/llama.h +81 -75
- data/vendor/tmp/llama.cpp/unicode.h +310 -1
- metadata +2 -2
@@ -1927,10 +1927,10 @@ static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(g
|
|
1927
1927
|
return ggml_backend_kompute_buffer_type(ctx->device);
|
1928
1928
|
}
|
1929
1929
|
|
1930
|
-
static
|
1930
|
+
static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
1931
1931
|
auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
|
1932
1932
|
ggml_vk_graph_compute(ctx, cgraph);
|
1933
|
-
return
|
1933
|
+
return GGML_STATUS_SUCCESS;
|
1934
1934
|
}
|
1935
1935
|
|
1936
1936
|
static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
@@ -1953,11 +1953,17 @@ static struct ggml_backend_i kompute_backend_i = {
|
|
1953
1953
|
/* .supports_op = */ ggml_backend_kompute_supports_op,
|
1954
1954
|
};
|
1955
1955
|
|
1956
|
+
static ggml_guid_t ggml_backend_kompute_guid() {
|
1957
|
+
static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
|
1958
|
+
return &guid;
|
1959
|
+
}
|
1960
|
+
|
1956
1961
|
ggml_backend_t ggml_backend_kompute_init(int device) {
|
1957
1962
|
GGML_ASSERT(s_kompute_context == nullptr);
|
1958
1963
|
s_kompute_context = new ggml_kompute_context(device);
|
1959
1964
|
|
1960
1965
|
ggml_backend_t kompute_backend = new ggml_backend {
|
1966
|
+
/* .guid = */ ggml_backend_kompute_guid(),
|
1961
1967
|
/* .interface = */ kompute_backend_i,
|
1962
1968
|
/* .context = */ s_kompute_context,
|
1963
1969
|
};
|
@@ -1966,7 +1972,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
|
|
1966
1972
|
}
|
1967
1973
|
|
1968
1974
|
bool ggml_backend_is_kompute(ggml_backend_t backend) {
|
1969
|
-
return backend && backend->
|
1975
|
+
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
|
1970
1976
|
}
|
1971
1977
|
|
1972
1978
|
static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
|
@@ -61,8 +61,11 @@ enum ggml_metal_kernel_type {
|
|
61
61
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
62
62
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
63
63
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
64
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
|
65
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
|
64
66
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
65
67
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
68
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
66
69
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
67
70
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
68
71
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
@@ -85,8 +88,11 @@ enum ggml_metal_kernel_type {
|
|
85
88
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
86
89
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
87
90
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
91
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
|
92
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
|
88
93
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
89
94
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
95
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
90
96
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
91
97
|
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
92
98
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
@@ -105,8 +111,11 @@ enum ggml_metal_kernel_type {
|
|
105
111
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
106
112
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
107
113
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
114
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
|
115
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
|
108
116
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
109
117
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
118
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
110
119
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
111
120
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
112
121
|
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
@@ -122,8 +131,11 @@ enum ggml_metal_kernel_type {
|
|
122
131
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
123
132
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
124
133
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
134
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
|
135
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
|
125
136
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
126
137
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
138
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
127
139
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
128
140
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
129
141
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
@@ -139,8 +151,11 @@ enum ggml_metal_kernel_type {
|
|
139
151
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
140
152
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
141
153
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
154
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
|
155
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
|
142
156
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
143
157
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
158
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
144
159
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
145
160
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
146
161
|
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
@@ -148,6 +163,8 @@ enum ggml_metal_kernel_type {
|
|
148
163
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
149
164
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
150
165
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
166
|
+
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
167
|
+
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
|
151
168
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
152
169
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
153
170
|
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
@@ -452,8 +469,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
452
469
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
453
470
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
454
471
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
472
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
473
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
455
474
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
456
475
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
476
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
457
477
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
458
478
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
459
479
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
@@ -476,8 +496,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
476
496
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
477
497
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
478
498
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
499
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
500
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
479
501
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
480
502
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
503
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
481
504
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
482
505
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
483
506
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
@@ -496,8 +519,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
496
519
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
497
520
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
498
521
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
522
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
523
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
499
524
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
500
525
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
526
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
501
527
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
502
528
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
503
529
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
@@ -513,8 +539,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
513
539
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
514
540
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
515
541
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
542
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
543
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
516
544
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
517
545
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
546
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
518
547
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
519
548
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
520
549
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
@@ -530,8 +559,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
530
559
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
531
560
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
532
561
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
562
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
563
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
533
564
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
534
565
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
566
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
535
567
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
536
568
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
537
569
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
@@ -539,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
539
571
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
540
572
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
541
573
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
574
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
575
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
542
576
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
543
577
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
544
578
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
@@ -667,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
667
701
|
return false;
|
668
702
|
case GGML_OP_UPSCALE:
|
669
703
|
case GGML_OP_PAD:
|
704
|
+
case GGML_OP_ARANGE:
|
705
|
+
case GGML_OP_TIMESTEP_EMBEDDING:
|
670
706
|
case GGML_OP_ARGSORT:
|
671
707
|
case GGML_OP_LEAKY_RELU:
|
672
708
|
return true;
|
@@ -712,7 +748,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
712
748
|
}
|
713
749
|
}
|
714
750
|
|
715
|
-
static
|
751
|
+
static enum ggml_status ggml_metal_graph_compute(
|
716
752
|
struct ggml_metal_context * ctx,
|
717
753
|
struct ggml_cgraph * gf) {
|
718
754
|
|
@@ -1061,7 +1097,8 @@ static bool ggml_metal_graph_compute(
|
|
1061
1097
|
{
|
1062
1098
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
1063
1099
|
|
1064
|
-
|
1100
|
+
float scale;
|
1101
|
+
memcpy(&scale, dst->op_params, sizeof(scale));
|
1065
1102
|
|
1066
1103
|
int64_t n = ggml_nelements(dst);
|
1067
1104
|
|
@@ -1220,11 +1257,15 @@ static bool ggml_metal_graph_compute(
|
|
1220
1257
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
1221
1258
|
}
|
1222
1259
|
|
1223
|
-
|
1224
|
-
|
1260
|
+
float scale;
|
1261
|
+
float max_bias;
|
1262
|
+
|
1263
|
+
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
1264
|
+
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
1225
1265
|
|
1226
1266
|
const int64_t nrows_x = ggml_nrows(src0);
|
1227
1267
|
const int64_t nrows_y = src0->ne[1];
|
1268
|
+
|
1228
1269
|
const uint32_t n_head_kv = nrows_x/nrows_y;
|
1229
1270
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
1230
1271
|
|
@@ -1347,8 +1388,11 @@ static bool ggml_metal_graph_compute(
|
|
1347
1388
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
1348
1389
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
1349
1390
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
1391
|
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
1392
|
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
1350
1393
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
1351
1394
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
1395
|
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
1352
1396
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
1353
1397
|
}
|
1354
1398
|
|
@@ -1483,6 +1527,18 @@ static bool ggml_metal_graph_compute(
|
|
1483
1527
|
nth1 = 16;
|
1484
1528
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
1485
1529
|
} break;
|
1530
|
+
case GGML_TYPE_IQ3_S:
|
1531
|
+
{
|
1532
|
+
nth0 = 4;
|
1533
|
+
nth1 = 16;
|
1534
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
1535
|
+
} break;
|
1536
|
+
case GGML_TYPE_IQ2_S:
|
1537
|
+
{
|
1538
|
+
nth0 = 4;
|
1539
|
+
nth1 = 16;
|
1540
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
1541
|
+
} break;
|
1486
1542
|
case GGML_TYPE_IQ1_S:
|
1487
1543
|
{
|
1488
1544
|
nth0 = 4;
|
@@ -1495,6 +1551,12 @@ static bool ggml_metal_graph_compute(
|
|
1495
1551
|
nth1 = 16;
|
1496
1552
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
1497
1553
|
} break;
|
1554
|
+
case GGML_TYPE_IQ4_XS:
|
1555
|
+
{
|
1556
|
+
nth0 = 4;
|
1557
|
+
nth1 = 16;
|
1558
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
1559
|
+
} break;
|
1498
1560
|
default:
|
1499
1561
|
{
|
1500
1562
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
@@ -1527,9 +1589,9 @@ static bool ggml_metal_graph_compute(
|
|
1527
1589
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
1528
1590
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
1529
1591
|
|
1530
|
-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1
|
1531
|
-
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1
|
1532
|
-
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S
|
1592
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
1593
|
+
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
1594
|
+
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
|
1533
1595
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1534
1596
|
}
|
1535
1597
|
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
@@ -1537,12 +1599,12 @@ static bool ggml_metal_graph_compute(
|
|
1537
1599
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1538
1600
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1539
1601
|
}
|
1540
|
-
else if (src0t == GGML_TYPE_IQ3_XXS) {
|
1541
|
-
const int mem_size = 256*4+128;
|
1602
|
+
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
1603
|
+
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
1542
1604
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1543
1605
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1544
1606
|
}
|
1545
|
-
else if (src0t == GGML_TYPE_IQ4_NL) {
|
1607
|
+
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
1546
1608
|
const int mem_size = 32*sizeof(float);
|
1547
1609
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1548
1610
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -1640,8 +1702,11 @@ static bool ggml_metal_graph_compute(
|
|
1640
1702
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
1641
1703
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
1642
1704
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
1705
|
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
1706
|
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
1643
1707
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
1644
1708
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
1709
|
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
1645
1710
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
1646
1711
|
}
|
1647
1712
|
|
@@ -1779,6 +1844,18 @@ static bool ggml_metal_graph_compute(
|
|
1779
1844
|
nth1 = 16;
|
1780
1845
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
1781
1846
|
} break;
|
1847
|
+
case GGML_TYPE_IQ3_S:
|
1848
|
+
{
|
1849
|
+
nth0 = 4;
|
1850
|
+
nth1 = 16;
|
1851
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
|
1852
|
+
} break;
|
1853
|
+
case GGML_TYPE_IQ2_S:
|
1854
|
+
{
|
1855
|
+
nth0 = 4;
|
1856
|
+
nth1 = 16;
|
1857
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
|
1858
|
+
} break;
|
1782
1859
|
case GGML_TYPE_IQ1_S:
|
1783
1860
|
{
|
1784
1861
|
nth0 = 4;
|
@@ -1791,6 +1868,12 @@ static bool ggml_metal_graph_compute(
|
|
1791
1868
|
nth1 = 16;
|
1792
1869
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
1793
1870
|
} break;
|
1871
|
+
case GGML_TYPE_IQ4_XS:
|
1872
|
+
{
|
1873
|
+
nth0 = 4;
|
1874
|
+
nth1 = 16;
|
1875
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
1876
|
+
} break;
|
1794
1877
|
default:
|
1795
1878
|
{
|
1796
1879
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
@@ -1839,9 +1922,9 @@ static bool ggml_metal_graph_compute(
|
|
1839
1922
|
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
1840
1923
|
}
|
1841
1924
|
|
1842
|
-
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1
|
1843
|
-
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1
|
1844
|
-
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S
|
1925
|
+
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
1926
|
+
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
1927
|
+
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
|
1845
1928
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1846
1929
|
}
|
1847
1930
|
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
@@ -1849,12 +1932,12 @@ static bool ggml_metal_graph_compute(
|
|
1849
1932
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1850
1933
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1851
1934
|
}
|
1852
|
-
else if (src2t == GGML_TYPE_IQ3_XXS) {
|
1853
|
-
const int mem_size = 256*4+128;
|
1935
|
+
else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
|
1936
|
+
const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
1854
1937
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1855
1938
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1856
1939
|
}
|
1857
|
-
else if (src2t == GGML_TYPE_IQ4_NL) {
|
1940
|
+
else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
|
1858
1941
|
const int mem_size = 32*sizeof(float);
|
1859
1942
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1860
1943
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -1900,8 +1983,11 @@ static bool ggml_metal_graph_compute(
|
|
1900
1983
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
1901
1984
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
1902
1985
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
1986
|
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
|
1987
|
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
|
1903
1988
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
1904
1989
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
|
1990
|
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
|
1905
1991
|
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
1906
1992
|
default: GGML_ASSERT(false && "not implemented");
|
1907
1993
|
}
|
@@ -2011,6 +2097,7 @@ static bool ggml_metal_graph_compute(
|
|
2011
2097
|
|
2012
2098
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
2013
2099
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
2100
|
+
|
2014
2101
|
float max_bias;
|
2015
2102
|
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
2016
2103
|
|
@@ -2225,6 +2312,50 @@ static bool ggml_metal_graph_compute(
|
|
2225
2312
|
|
2226
2313
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2227
2314
|
} break;
|
2315
|
+
case GGML_OP_ARANGE:
|
2316
|
+
{
|
2317
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
2318
|
+
|
2319
|
+
float start;
|
2320
|
+
float step;
|
2321
|
+
|
2322
|
+
memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
2323
|
+
memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
|
2324
|
+
|
2325
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
2326
|
+
|
2327
|
+
[encoder setComputePipelineState:pipeline];
|
2328
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
2329
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
|
2330
|
+
[encoder setBytes:&start length:sizeof(start) atIndex:2];
|
2331
|
+
[encoder setBytes:&step length:sizeof(step) atIndex:3];
|
2332
|
+
|
2333
|
+
const int nth = MIN(1024, ne0);
|
2334
|
+
|
2335
|
+
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2336
|
+
} break;
|
2337
|
+
case GGML_OP_TIMESTEP_EMBEDDING:
|
2338
|
+
{
|
2339
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2340
|
+
|
2341
|
+
const int dim = dst->op_params[0];
|
2342
|
+
const int max_period = dst->op_params[1];
|
2343
|
+
|
2344
|
+
const int half = dim / 2;
|
2345
|
+
|
2346
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
|
2347
|
+
|
2348
|
+
[encoder setComputePipelineState:pipeline];
|
2349
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2350
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2351
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
|
2352
|
+
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
|
2353
|
+
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
|
2354
|
+
|
2355
|
+
const int nth = MIN(1024, half);
|
2356
|
+
|
2357
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2358
|
+
} break;
|
2228
2359
|
case GGML_OP_ARGSORT:
|
2229
2360
|
{
|
2230
2361
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
@@ -2237,8 +2368,8 @@ static bool ggml_metal_graph_compute(
|
|
2237
2368
|
id<MTLComputePipelineState> pipeline = nil;
|
2238
2369
|
|
2239
2370
|
switch (order) {
|
2240
|
-
case
|
2241
|
-
case
|
2371
|
+
case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
|
2372
|
+
case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
|
2242
2373
|
default: GGML_ASSERT(false);
|
2243
2374
|
};
|
2244
2375
|
|
@@ -2353,7 +2484,7 @@ static bool ggml_metal_graph_compute(
|
|
2353
2484
|
MTLCommandBufferStatus status = [command_buffer status];
|
2354
2485
|
if (status != MTLCommandBufferStatusCompleted) {
|
2355
2486
|
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
2356
|
-
return
|
2487
|
+
return GGML_STATUS_FAILED;
|
2357
2488
|
}
|
2358
2489
|
}
|
2359
2490
|
|
@@ -2362,7 +2493,7 @@ static bool ggml_metal_graph_compute(
|
|
2362
2493
|
}
|
2363
2494
|
|
2364
2495
|
}
|
2365
|
-
return
|
2496
|
+
return GGML_STATUS_SUCCESS;
|
2366
2497
|
}
|
2367
2498
|
|
2368
2499
|
////////////////////////////////////////////////////////////////////////////////
|
@@ -2664,7 +2795,7 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffe
|
|
2664
2795
|
UNUSED(backend);
|
2665
2796
|
}
|
2666
2797
|
|
2667
|
-
GGML_CALL static
|
2798
|
+
GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
2668
2799
|
struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
|
2669
2800
|
|
2670
2801
|
return ggml_metal_graph_compute(metal_ctx, cgraph);
|
@@ -2696,6 +2827,11 @@ void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void *
|
|
2696
2827
|
ggml_metal_log_user_data = user_data;
|
2697
2828
|
}
|
2698
2829
|
|
2830
|
+
static ggml_guid_t ggml_backend_metal_guid(void) {
|
2831
|
+
static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
|
2832
|
+
return &guid;
|
2833
|
+
}
|
2834
|
+
|
2699
2835
|
ggml_backend_t ggml_backend_metal_init(void) {
|
2700
2836
|
struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
2701
2837
|
|
@@ -2706,6 +2842,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
2706
2842
|
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
2707
2843
|
|
2708
2844
|
*metal_backend = (struct ggml_backend) {
|
2845
|
+
/* .guid = */ ggml_backend_metal_guid(),
|
2709
2846
|
/* .interface = */ ggml_backend_metal_i,
|
2710
2847
|
/* .context = */ ctx,
|
2711
2848
|
};
|
@@ -2714,7 +2851,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
2714
2851
|
}
|
2715
2852
|
|
2716
2853
|
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
2717
|
-
return backend && backend->
|
2854
|
+
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
|
2718
2855
|
}
|
2719
2856
|
|
2720
2857
|
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|