llama_cpp 0.12.6 → 0.13.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +21 -0
- data/ext/llama_cpp/llama_cpp.cpp +90 -269
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +28 -23
- data/vendor/tmp/llama.cpp/Makefile +51 -15
- data/vendor/tmp/llama.cpp/ggml-alloc.c +73 -43
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +2 -0
- data/vendor/tmp/llama.cpp/ggml-backend.c +32 -11
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +560 -346
- data/vendor/tmp/llama.cpp/ggml-impl.h +20 -7
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +191 -22
- data/vendor/tmp/llama.cpp/ggml-metal.metal +2472 -862
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +25 -25
- data/vendor/tmp/llama.cpp/ggml-quants.c +3176 -667
- data/vendor/tmp/llama.cpp/ggml-quants.h +77 -2
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +373 -424
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +186 -102
- data/vendor/tmp/llama.cpp/ggml.c +1266 -699
- data/vendor/tmp/llama.cpp/ggml.h +59 -30
- data/vendor/tmp/llama.cpp/llama.cpp +1517 -717
- data/vendor/tmp/llama.cpp/llama.h +87 -63
- data/vendor/tmp/llama.cpp/scripts/get-flags.mk +1 -1
- data/vendor/tmp/llama.cpp/unicode.h +310 -1
- metadata +2 -2
|
@@ -53,11 +53,23 @@ extern "C" {
|
|
|
53
53
|
//
|
|
54
54
|
#include <arm_neon.h>
|
|
55
55
|
|
|
56
|
-
#define GGML_COMPUTE_FP16_TO_FP32(x) (
|
|
57
|
-
#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
|
|
56
|
+
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
|
|
57
|
+
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
|
|
58
|
+
|
|
59
|
+
#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
|
|
60
|
+
|
|
61
|
+
static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
|
|
62
|
+
__fp16 tmp;
|
|
63
|
+
memcpy(&tmp, &h, sizeof(ggml_fp16_t));
|
|
64
|
+
return (float)tmp;
|
|
65
|
+
}
|
|
58
66
|
|
|
59
|
-
|
|
60
|
-
|
|
67
|
+
static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
|
68
|
+
ggml_fp16_t res;
|
|
69
|
+
__fp16 tmp = f;
|
|
70
|
+
memcpy(&res, &tmp, sizeof(ggml_fp16_t));
|
|
71
|
+
return res;
|
|
72
|
+
}
|
|
61
73
|
|
|
62
74
|
#else
|
|
63
75
|
|
|
@@ -214,8 +226,7 @@ extern float ggml_table_f32_f16[1 << 16];
|
|
|
214
226
|
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
|
|
215
227
|
// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
|
|
216
228
|
// This is also true for POWER9.
|
|
217
|
-
#if !defined(GGML_FP16_TO_FP32)
|
|
218
|
-
|
|
229
|
+
#if !defined(GGML_FP16_TO_FP32)
|
|
219
230
|
inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
|
220
231
|
uint16_t s;
|
|
221
232
|
memcpy(&s, &f, sizeof(uint16_t));
|
|
@@ -223,8 +234,10 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
|
|
223
234
|
}
|
|
224
235
|
|
|
225
236
|
#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
|
|
226
|
-
#
|
|
237
|
+
#endif
|
|
227
238
|
|
|
239
|
+
#if !defined(GGML_FP32_TO_FP16)
|
|
240
|
+
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
|
|
228
241
|
#endif
|
|
229
242
|
|
|
230
243
|
#define GGML_HASHTABLE_FULL ((size_t)-1)
|
|
@@ -1953,11 +1953,17 @@ static struct ggml_backend_i kompute_backend_i = {
|
|
|
1953
1953
|
/* .supports_op = */ ggml_backend_kompute_supports_op,
|
|
1954
1954
|
};
|
|
1955
1955
|
|
|
1956
|
+
static ggml_guid_t ggml_backend_kompute_guid() {
|
|
1957
|
+
static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
|
|
1958
|
+
return &guid;
|
|
1959
|
+
}
|
|
1960
|
+
|
|
1956
1961
|
ggml_backend_t ggml_backend_kompute_init(int device) {
|
|
1957
1962
|
GGML_ASSERT(s_kompute_context == nullptr);
|
|
1958
1963
|
s_kompute_context = new ggml_kompute_context(device);
|
|
1959
1964
|
|
|
1960
1965
|
ggml_backend_t kompute_backend = new ggml_backend {
|
|
1966
|
+
/* .guid = */ ggml_backend_kompute_guid(),
|
|
1961
1967
|
/* .interface = */ kompute_backend_i,
|
|
1962
1968
|
/* .context = */ s_kompute_context,
|
|
1963
1969
|
};
|
|
@@ -1966,7 +1972,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
|
|
|
1966
1972
|
}
|
|
1967
1973
|
|
|
1968
1974
|
bool ggml_backend_is_kompute(ggml_backend_t backend) {
|
|
1969
|
-
return backend && backend->
|
|
1975
|
+
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
|
|
1970
1976
|
}
|
|
1971
1977
|
|
|
1972
1978
|
static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
|
|
@@ -61,6 +61,11 @@ enum ggml_metal_kernel_type {
|
|
|
61
61
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
|
62
62
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
|
63
63
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
|
64
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
|
|
65
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
|
|
66
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
|
67
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
|
68
|
+
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
|
64
69
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
65
70
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
66
71
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
@@ -83,6 +88,11 @@ enum ggml_metal_kernel_type {
|
|
|
83
88
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
|
|
84
89
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
|
|
85
90
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
|
|
91
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
|
|
92
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
|
|
93
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
|
94
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
|
95
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
|
86
96
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
|
87
97
|
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
|
88
98
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
|
@@ -101,6 +111,11 @@ enum ggml_metal_kernel_type {
|
|
|
101
111
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
|
|
102
112
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
|
|
103
113
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
|
|
114
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
|
|
115
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
|
|
116
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
|
117
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
|
118
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
|
104
119
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
|
105
120
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
|
106
121
|
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
|
@@ -116,6 +131,11 @@ enum ggml_metal_kernel_type {
|
|
|
116
131
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
|
|
117
132
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
|
|
118
133
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
|
|
134
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
|
|
135
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
|
|
136
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
|
137
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
|
138
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
|
119
139
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
|
120
140
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
|
121
141
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
|
@@ -131,6 +151,11 @@ enum ggml_metal_kernel_type {
|
|
|
131
151
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
|
132
152
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
|
133
153
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
|
154
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
|
|
155
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
|
|
156
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
|
157
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
|
158
|
+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
|
134
159
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
|
135
160
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
|
136
161
|
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
@@ -176,7 +201,7 @@ struct ggml_metal_context {
|
|
|
176
201
|
// MSL code
|
|
177
202
|
// TODO: move the contents here when ready
|
|
178
203
|
// for now it is easier to work in a separate file
|
|
179
|
-
//static NSString * const msl_library_source = @"see metal.metal";
|
|
204
|
+
// static NSString * const msl_library_source = @"see metal.metal";
|
|
180
205
|
|
|
181
206
|
// Here to assist with NSBundle Path Hack
|
|
182
207
|
@interface GGMLMetalClass : NSObject
|
|
@@ -272,6 +297,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
272
297
|
return NULL;
|
|
273
298
|
}
|
|
274
299
|
} else {
|
|
300
|
+
#if GGML_METAL_EMBED_LIBRARY
|
|
301
|
+
GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);
|
|
302
|
+
|
|
303
|
+
extern const char ggml_metallib_start[];
|
|
304
|
+
extern const char ggml_metallib_end[];
|
|
305
|
+
|
|
306
|
+
NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
|
|
307
|
+
#else
|
|
275
308
|
GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
|
276
309
|
|
|
277
310
|
NSString * sourcePath;
|
|
@@ -294,6 +327,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
294
327
|
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
|
295
328
|
return NULL;
|
|
296
329
|
}
|
|
330
|
+
#endif
|
|
297
331
|
|
|
298
332
|
@autoreleasepool {
|
|
299
333
|
// dictionary of preprocessor macros
|
|
@@ -433,6 +467,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
433
467
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
|
434
468
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
|
435
469
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
|
470
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
|
471
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
|
472
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
|
473
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
|
474
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
|
436
475
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
437
476
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
|
438
477
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
|
@@ -455,6 +494,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
455
494
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
|
456
495
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
|
457
496
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
|
497
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
|
498
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
|
499
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
|
500
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
|
501
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
|
458
502
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
|
459
503
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
|
460
504
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
|
@@ -473,6 +517,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
473
517
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
|
474
518
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
|
475
519
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
|
520
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
|
521
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
|
522
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
|
523
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
|
524
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
|
476
525
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
|
477
526
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
|
478
527
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
|
@@ -488,6 +537,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
488
537
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
|
489
538
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
|
490
539
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
|
540
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
|
541
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
|
542
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
|
543
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
|
544
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
|
491
545
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
|
492
546
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
|
493
547
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
|
@@ -503,6 +557,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
503
557
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
|
504
558
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
|
505
559
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
|
560
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
|
561
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
|
562
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
|
563
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
|
564
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
|
506
565
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
|
507
566
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
|
508
567
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
@@ -728,6 +787,7 @@ static bool ggml_metal_graph_compute(
|
|
|
728
787
|
|
|
729
788
|
size_t offs_src0 = 0;
|
|
730
789
|
size_t offs_src1 = 0;
|
|
790
|
+
size_t offs_src2 = 0;
|
|
731
791
|
size_t offs_dst = 0;
|
|
732
792
|
|
|
733
793
|
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
|
@@ -746,6 +806,7 @@ static bool ggml_metal_graph_compute(
|
|
|
746
806
|
|
|
747
807
|
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
|
748
808
|
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
|
809
|
+
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
|
749
810
|
struct ggml_tensor * dst = gf->nodes[i];
|
|
750
811
|
|
|
751
812
|
switch (dst->op) {
|
|
@@ -807,6 +868,7 @@ static bool ggml_metal_graph_compute(
|
|
|
807
868
|
|
|
808
869
|
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
|
809
870
|
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
|
871
|
+
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
|
810
872
|
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
|
811
873
|
|
|
812
874
|
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
|
@@ -1188,7 +1250,16 @@ static bool ggml_metal_graph_compute(
|
|
|
1188
1250
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
|
1189
1251
|
}
|
|
1190
1252
|
|
|
1191
|
-
const float scale
|
|
1253
|
+
const float scale = ((float *) dst->op_params)[0];
|
|
1254
|
+
const float max_bias = ((float *) dst->op_params)[1];
|
|
1255
|
+
|
|
1256
|
+
const int64_t nrows_x = ggml_nrows(src0);
|
|
1257
|
+
const int64_t nrows_y = src0->ne[1];
|
|
1258
|
+
const uint32_t n_head_kv = nrows_x/nrows_y;
|
|
1259
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
|
1260
|
+
|
|
1261
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
1262
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
1192
1263
|
|
|
1193
1264
|
[encoder setComputePipelineState:pipeline];
|
|
1194
1265
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
@@ -1197,11 +1268,20 @@ static bool ggml_metal_graph_compute(
|
|
|
1197
1268
|
} else {
|
|
1198
1269
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
1199
1270
|
}
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1271
|
+
if (id_src2) {
|
|
1272
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
|
1273
|
+
} else {
|
|
1274
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
|
1275
|
+
}
|
|
1276
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
1277
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
|
|
1278
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
|
|
1279
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
|
|
1280
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
|
|
1281
|
+
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
|
|
1282
|
+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
|
|
1283
|
+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
|
|
1284
|
+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
|
|
1205
1285
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
1206
1286
|
|
|
1207
1287
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
@@ -1297,6 +1377,11 @@ static bool ggml_metal_graph_compute(
|
|
|
1297
1377
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
|
1298
1378
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
|
1299
1379
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
|
1380
|
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
|
1381
|
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
|
1382
|
+
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
|
1383
|
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
|
1384
|
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
|
1300
1385
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
|
1301
1386
|
}
|
|
1302
1387
|
|
|
@@ -1431,6 +1516,36 @@ static bool ggml_metal_graph_compute(
|
|
|
1431
1516
|
nth1 = 16;
|
|
1432
1517
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
|
1433
1518
|
} break;
|
|
1519
|
+
case GGML_TYPE_IQ3_S:
|
|
1520
|
+
{
|
|
1521
|
+
nth0 = 4;
|
|
1522
|
+
nth1 = 16;
|
|
1523
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
|
1524
|
+
} break;
|
|
1525
|
+
case GGML_TYPE_IQ2_S:
|
|
1526
|
+
{
|
|
1527
|
+
nth0 = 4;
|
|
1528
|
+
nth1 = 16;
|
|
1529
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
|
1530
|
+
} break;
|
|
1531
|
+
case GGML_TYPE_IQ1_S:
|
|
1532
|
+
{
|
|
1533
|
+
nth0 = 4;
|
|
1534
|
+
nth1 = 16;
|
|
1535
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
|
1536
|
+
} break;
|
|
1537
|
+
case GGML_TYPE_IQ4_NL:
|
|
1538
|
+
{
|
|
1539
|
+
nth0 = 4;
|
|
1540
|
+
nth1 = 16;
|
|
1541
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
|
1542
|
+
} break;
|
|
1543
|
+
case GGML_TYPE_IQ4_XS:
|
|
1544
|
+
{
|
|
1545
|
+
nth0 = 4;
|
|
1546
|
+
nth1 = 16;
|
|
1547
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
|
1548
|
+
} break;
|
|
1434
1549
|
default:
|
|
1435
1550
|
{
|
|
1436
1551
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
@@ -1463,9 +1578,9 @@ static bool ggml_metal_graph_compute(
|
|
|
1463
1578
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
|
1464
1579
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
|
1465
1580
|
|
|
1466
|
-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1
|
|
1467
|
-
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1
|
|
1468
|
-
src0t == GGML_TYPE_Q2_K
|
|
1581
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
|
1582
|
+
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
|
1583
|
+
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
|
|
1469
1584
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1470
1585
|
}
|
|
1471
1586
|
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
|
@@ -1473,11 +1588,16 @@ static bool ggml_metal_graph_compute(
|
|
|
1473
1588
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
1474
1589
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1475
1590
|
}
|
|
1476
|
-
else if (src0t == GGML_TYPE_IQ3_XXS) {
|
|
1477
|
-
const int mem_size = 256*4+128;
|
|
1591
|
+
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
|
1592
|
+
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
|
1478
1593
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
1479
1594
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1480
1595
|
}
|
|
1596
|
+
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
|
1597
|
+
const int mem_size = 32*sizeof(float);
|
|
1598
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
1599
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1600
|
+
}
|
|
1481
1601
|
else if (src0t == GGML_TYPE_Q4_K) {
|
|
1482
1602
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1483
1603
|
}
|
|
@@ -1514,8 +1634,6 @@ static bool ggml_metal_graph_compute(
|
|
|
1514
1634
|
// max size of the src1ids array in the kernel stack
|
|
1515
1635
|
GGML_ASSERT(ne11 <= 512);
|
|
1516
1636
|
|
|
1517
|
-
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
|
1518
|
-
|
|
1519
1637
|
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
|
1520
1638
|
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
|
1521
1639
|
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
|
@@ -1573,6 +1691,11 @@ static bool ggml_metal_graph_compute(
|
|
|
1573
1691
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
|
1574
1692
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
|
1575
1693
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
|
1694
|
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
|
1695
|
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
|
1696
|
+
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
|
1697
|
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
|
1698
|
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
|
1576
1699
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
|
1577
1700
|
}
|
|
1578
1701
|
|
|
@@ -1710,6 +1833,36 @@ static bool ggml_metal_graph_compute(
|
|
|
1710
1833
|
nth1 = 16;
|
|
1711
1834
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
|
1712
1835
|
} break;
|
|
1836
|
+
case GGML_TYPE_IQ3_S:
|
|
1837
|
+
{
|
|
1838
|
+
nth0 = 4;
|
|
1839
|
+
nth1 = 16;
|
|
1840
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
|
|
1841
|
+
} break;
|
|
1842
|
+
case GGML_TYPE_IQ2_S:
|
|
1843
|
+
{
|
|
1844
|
+
nth0 = 4;
|
|
1845
|
+
nth1 = 16;
|
|
1846
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
|
|
1847
|
+
} break;
|
|
1848
|
+
case GGML_TYPE_IQ1_S:
|
|
1849
|
+
{
|
|
1850
|
+
nth0 = 4;
|
|
1851
|
+
nth1 = 16;
|
|
1852
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
|
|
1853
|
+
} break;
|
|
1854
|
+
case GGML_TYPE_IQ4_NL:
|
|
1855
|
+
{
|
|
1856
|
+
nth0 = 4;
|
|
1857
|
+
nth1 = 16;
|
|
1858
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
|
1859
|
+
} break;
|
|
1860
|
+
case GGML_TYPE_IQ4_XS:
|
|
1861
|
+
{
|
|
1862
|
+
nth0 = 4;
|
|
1863
|
+
nth1 = 16;
|
|
1864
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
|
1865
|
+
} break;
|
|
1713
1866
|
default:
|
|
1714
1867
|
{
|
|
1715
1868
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
@@ -1758,9 +1911,9 @@ static bool ggml_metal_graph_compute(
|
|
|
1758
1911
|
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
|
1759
1912
|
}
|
|
1760
1913
|
|
|
1761
|
-
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1
|
|
1762
|
-
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1
|
|
1763
|
-
src2t == GGML_TYPE_Q2_K
|
|
1914
|
+
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
|
1915
|
+
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
|
1916
|
+
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
|
|
1764
1917
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1765
1918
|
}
|
|
1766
1919
|
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
|
@@ -1768,11 +1921,16 @@ static bool ggml_metal_graph_compute(
|
|
|
1768
1921
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
1769
1922
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1770
1923
|
}
|
|
1771
|
-
else if (src2t == GGML_TYPE_IQ3_XXS) {
|
|
1772
|
-
const int mem_size = 256*4+128;
|
|
1924
|
+
else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
|
|
1925
|
+
const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
|
1773
1926
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
1774
1927
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1775
1928
|
}
|
|
1929
|
+
else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
|
|
1930
|
+
const int mem_size = 32*sizeof(float);
|
|
1931
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
1932
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1933
|
+
}
|
|
1776
1934
|
else if (src2t == GGML_TYPE_Q4_K) {
|
|
1777
1935
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1778
1936
|
}
|
|
@@ -1814,6 +1972,11 @@ static bool ggml_metal_graph_compute(
|
|
|
1814
1972
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
|
|
1815
1973
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
|
|
1816
1974
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
|
|
1975
|
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
|
|
1976
|
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
|
|
1977
|
+
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
|
1978
|
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
|
|
1979
|
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
|
|
1817
1980
|
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
|
1818
1981
|
default: GGML_ASSERT(false && "not implemented");
|
|
1819
1982
|
}
|
|
@@ -2149,8 +2312,8 @@ static bool ggml_metal_graph_compute(
|
|
|
2149
2312
|
id<MTLComputePipelineState> pipeline = nil;
|
|
2150
2313
|
|
|
2151
2314
|
switch (order) {
|
|
2152
|
-
case
|
|
2153
|
-
case
|
|
2315
|
+
case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
|
|
2316
|
+
case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
|
|
2154
2317
|
default: GGML_ASSERT(false);
|
|
2155
2318
|
};
|
|
2156
2319
|
|
|
@@ -2608,6 +2771,11 @@ void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void *
|
|
|
2608
2771
|
ggml_metal_log_user_data = user_data;
|
|
2609
2772
|
}
|
|
2610
2773
|
|
|
2774
|
+
static ggml_guid_t ggml_backend_metal_guid(void) {
|
|
2775
|
+
static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
|
|
2776
|
+
return &guid;
|
|
2777
|
+
}
|
|
2778
|
+
|
|
2611
2779
|
ggml_backend_t ggml_backend_metal_init(void) {
|
|
2612
2780
|
struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
|
2613
2781
|
|
|
@@ -2618,6 +2786,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
|
2618
2786
|
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
|
2619
2787
|
|
|
2620
2788
|
*metal_backend = (struct ggml_backend) {
|
|
2789
|
+
/* .guid = */ ggml_backend_metal_guid(),
|
|
2621
2790
|
/* .interface = */ ggml_backend_metal_i,
|
|
2622
2791
|
/* .context = */ ctx,
|
|
2623
2792
|
};
|
|
@@ -2626,7 +2795,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
|
|
|
2626
2795
|
}
|
|
2627
2796
|
|
|
2628
2797
|
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
|
2629
|
-
return backend && backend->
|
|
2798
|
+
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
|
|
2630
2799
|
}
|
|
2631
2800
|
|
|
2632
2801
|
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|