llama_cpp 0.15.0 → 0.15.2
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/llama_cpp.cpp +6 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +6 -7
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +303 -23
- data/vendor/tmp/llama.cpp/ggml-impl.h +84 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +137 -133
- data/vendor/tmp/llama.cpp/ggml-metal.metal +87 -110
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +2220 -28
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1032 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +35 -152
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +953 -268
- data/vendor/tmp/llama.cpp/ggml.c +1762 -681
- data/vendor/tmp/llama.cpp/ggml.h +43 -24
- data/vendor/tmp/llama.cpp/llama.cpp +533 -296
- data/vendor/tmp/llama.cpp/llama.h +10 -1
- data/vendor/tmp/llama.cpp/sgemm.cpp +56 -21
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -1637
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -11
- data/vendor/tmp/llama.cpp/unicode.cpp +286 -176
- data/vendor/tmp/llama.cpp/unicode.h +44 -10
- metadata +4 -2
data/vendor/tmp/llama.cpp/ggml-impl.h:

@@ -17,6 +17,83 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+/**
+ * Converts brain16 to float32.
+ *
+ * The bfloat16 floating point format has the following structure:
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───┐
+ *     0b0000000000000000 brain16
+ *
+ * Since bf16 has the same number of exponent bits as a 32bit float,
+ * encoding and decoding numbers becomes relatively straightforward.
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───────────────────┐
+ *     0b00000000000000000000000000000000 IEEE binary32
+ *
+ * For comparison, the standard fp16 format has fewer exponent bits.
+ *
+ *       ┌sign
+ *       │
+ *       │  ┌exponent
+ *       │  │
+ *       │  │    ┌mantissa
+ *       │  │    │
+ *       │┌─┴─┐┌─┴──────┐
+ *     0b0000000000000000 IEEE binary16
+ *
+ * @see IEEE 754-2008
+ */
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
+/**
+ * Converts float32 to brain16.
+ *
+ * This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
+ * Subnormals shall be flushed to zero, and NANs will be quiet.
+ * This code should vectorize nicely if using modern compilers.
+ */
+static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
+    ggml_bf16_t h;
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.f = s;
+    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
+        h.bits = (u.i >> 16) | 64; /* force to quiet */
+        return h;
+    }
+    if (!(u.i & 0x7f800000)) { /* subnormal */
+        h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
+        return h;
+    }
+    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
+    return h;
+}
+
+#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
+#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
+
 #ifdef __cplusplus
 extern "C" {
 #endif
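The rounding logic is easy to check on the CPU. Below is a minimal standalone sketch of the two helpers above; `my_bf16_t` is a stand-in for `ggml_bf16_t`, assumed here to be a struct with a single `uint16_t bits` field:

    #include <stdint.h>
    #include <stdio.h>

    typedef struct { uint16_t bits; } my_bf16_t; /* assumed stand-in for ggml_bf16_t */

    static inline float bf16_to_fp32(my_bf16_t h) {
        union { float f; uint32_t i; } u;
        u.i = (uint32_t)h.bits << 16;      /* bf16 is just the top half of fp32 */
        return u.f;
    }

    static inline my_bf16_t fp32_to_bf16(float s) {
        my_bf16_t h;
        union { float f; uint32_t i; } u;
        u.f = s;
        if ((u.i & 0x7fffffff) > 0x7f800000) {        /* nan -> quiet nan */
            h.bits = (u.i >> 16) | 64;
            return h;
        }
        if (!(u.i & 0x7f800000)) {                    /* subnormal -> signed zero */
            h.bits = (u.i & 0x80000000) >> 16;
            return h;
        }
        h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16; /* round to nearest even */
        return h;
    }

    int main(void) {
        float x = 3.14159265f;                        /* fp32 bits 0x40490fdb */
        my_bf16_t b = fp32_to_bf16(x);
        printf("0x%04x %f\n", b.bits, bf16_to_fp32(b)); /* prints 0x4049 3.140625 */
        return 0;
    }

Keeping only the top 16 bits leaves 8 mantissa bits, so roughly two to three decimal digits survive the round trip.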
@@ -43,9 +120,16 @@ extern "C" {
 #ifndef __F16C__
 #define __F16C__
 #endif
+#endif
+
+// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
+#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
 #ifndef __SSE3__
 #define __SSE3__
 #endif
+#ifndef __SSSE3__
+#define __SSSE3__
+#endif
 #endif
 
 // 16-bit float
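With the `__SSSE3__` shim defined, SSSE3-guarded code paths compile under MSVC whenever AVX or newer is enabled. A small illustration under that assumption (the `abs_bytes` wrapper is hypothetical; `_mm_abs_epi8` is the real SSSE3 intrinsic from `tmmintrin.h`):

    #ifdef __SSSE3__
    #include <tmmintrin.h>

    /* hypothetical helper: absolute value of 16 signed bytes via the SSSE3 intrinsic */
    static inline __m128i abs_bytes(__m128i x) {
        return _mm_abs_epi8(x);
    }
    #endif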
data/vendor/tmp/llama.cpp/ggml-kompute.cpp:

@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
             case GGML_OP_SOFT_MAX:
                 {
                     float scale;
-                    memcpy(&scale, dst->op_params, sizeof(float));
+                    float max_bias;
 
-#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
+                    memcpy(&scale,    (float *)dst->op_params + 0, sizeof(float));
+                    memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
 #pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
                     GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
-
+
+#pragma message("TODO: add ALiBi support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/7192")
+                    GGML_ASSERT(max_bias == 0.0f);
 
                     ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                 } break;
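For context, `op_params` on a ggml tensor is a small raw `int32_t` array, and the soft-max op stores `scale` and `max_bias` as its first two floats. A hedged sketch of that packing convention (names hypothetical):

    #include <stdint.h>
    #include <string.h>

    /* hypothetical mirror of ggml's op_params blob: raw int32 storage reinterpreted as floats */
    static int32_t op_params[16];

    static void write_softmax_params(float scale, float max_bias) {
        memcpy((float *) op_params + 0, &scale,    sizeof(float)); /* slot 0: scale */
        memcpy((float *) op_params + 1, &max_bias, sizeof(float)); /* slot 1: max_bias */
    }

    static void read_softmax_params(float * scale, float * max_bias) {
        memcpy(scale,    (float *) op_params + 0, sizeof(float));
        memcpy(max_bias, (float *) op_params + 1, sizeof(float));
    }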
data/vendor/tmp/llama.cpp/ggml-metal.m:

@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CLAMP,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
-    GGML_METAL_KERNEL_TYPE_ALIBI_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -265,11 +265,20 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
 
 static void * ggml_metal_host_malloc(size_t n) {
     void * data = NULL;
+
+#if TARGET_OS_OSX
+    kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
+    if (err != KERN_SUCCESS) {
+        GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
+        return NULL;
+    }
+#else
     const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
     if (result != 0) {
         GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
         return NULL;
     }
+#endif
 
     return data;
 }
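Both branches hand back page-granular memory, which Metal later requires for no-copy buffers. A standalone sketch of the non-macOS branch, assuming only POSIX:

    #include <stdio.h>
    #include <stdlib.h>
    #include <unistd.h>

    /* minimal sketch: page-aligned host allocation as in the #else branch above */
    static void * host_malloc(size_t n) {
        void * data = NULL;
        if (posix_memalign(&data, (size_t) sysconf(_SC_PAGESIZE), n) != 0) {
            fprintf(stderr, "posix_memalign failed\n");
            return NULL;
        }
        return data;
    }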
@@ -485,6 +494,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@@ -614,7 +624,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -624,14 +633,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -723,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
@@ -750,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
         case GGML_OP_GROUP_NORM:
             return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
-        case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
             return true;
@@ -763,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_FLASH_ATTN_EXT:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@@ -803,7 +813,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
-                return op->ne[3] == 1;
+                return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
             }
         default:
             return false;
@@ -1185,24 +1195,24 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_CLAMP:
-                {
-                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
 
-                    float min;
-                    float max;
-                    memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
-                    memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
+                    float min;
+                    float max;
+                    memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+                    memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
 
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                    [encoder setBytes:&min length:sizeof(min) atIndex:2];
-                    [encoder setBytes:&max length:sizeof(max) atIndex:3];
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBytes:&min length:sizeof(min) atIndex:2];
+                    [encoder setBytes:&max length:sizeof(max) atIndex:3];
 
-                    const int64_t n = ggml_nelements(dst);
+                    const int64_t n = ggml_nelements(dst);
 
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             case GGML_OP_UNARY:
                 switch (ggml_get_unary_op(gf->nodes[i])) {
                     // we are not taking into account the strides, so for now require contiguous tensors
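Semantically, GGML_OP_CLAMP clips every element to `[min, max]`. A CPU reference loop equivalent to what the kernel computes per element (a sketch, not the Metal source):

    #include <stddef.h>

    /* MIN/MAX as defined in ggml-impl.h */
    #define MIN(a, b) ((a) < (b) ? (a) : (b))
    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    static void clamp_f32(const float * src, float * dst, size_t n, float min_v, float max_v) {
        for (size_t i = 0; i < n; ++i) {
            dst[i] = MIN(MAX(src[i], min_v), max_v);
        }
    }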
@@ -1230,6 +1240,18 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_UNARY_OP_SIGMOID:
+                    {
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_UNARY_OP_GELU:
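The new sigmoid op evaluates `1 / (1 + e^(-x))` elementwise; a CPU-side reference sketch of the same computation:

    #include <math.h>
    #include <stdint.h>

    static void sigmoid_f32(const float * src, float * dst, int64_t n) {
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = 1.0f / (1.0f + expf(-src[i]));
        }
    }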
@@ -1348,16 +1370,15 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_SOFT_MAX:
                 {
                     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-                    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
 
                     int nth = 32; // SIMD width
 
                     id<MTLComputePipelineState> pipeline = nil;
 
-                    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+                    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
                     if (ne00%4 == 0) {
-                        while (nth < ne00/4 && nth < 256) {
+                        while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
                             nth *= 2;
                         }
                         if (use_f16) {
@@ -1366,7 +1387,7 @@ static enum ggml_status ggml_metal_graph_compute(
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
                         }
                     } else {
-                        while (nth < ne00 && nth < 256) {
+                        while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
                             nth *= 2;
                         }
                         if (use_f16) {
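The new loop bound throttles the threadgroup width by the total number of rows in the dispatch rather than by row width alone. A worked sketch with hypothetical sizes:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne00 = 512, ne01 = 32, ne02 = 1, ne03 = 1; /* hypothetical shape */

        int nth = 32; // SIMD width
        /* old bound (nth < 256) would grow nth to 128 here;
           new bound stays at 32 because 32*32 >= 256 */
        while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
            nth *= 2;
        }
        printf("nth = %d\n", nth); /* prints 32 */
        return 0;
    }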
@@ -1385,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute(
                     const int64_t nrows_x = ggml_nrows(src0);
                     const int64_t nrows_y = src0->ne[1];
 
-                    const uint32_t n_head_kv   = nrows_x/nrows_y;
-                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+                    const uint32_t n_head      = nrows_x/nrows_y;
+                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
                     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
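`m0` and `m1` are the ALiBi slope bases that previously lived in the dedicated ALIBI kernel: the first `n_head_log2` heads take powers of `m0`, the remaining heads take odd powers of `m1`. A sketch of the per-head slope as computed in ggml's soft-max path:

    #include <math.h>
    #include <stdint.h>

    static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        /* heads below the nearest power of two: m0^(h+1); the rest: m1^(2*(h - n_head_log2) + 1) */
        return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    }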
@@ -1398,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute(
                     } else {
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                     }
-                    if (id_src2) {
-                        [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                    } else {
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
-                    }
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
-                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
-                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
-                    [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
-                    [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
-                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
-                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
-                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                    [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                    [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
                     [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -2216,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
-            case GGML_OP_ALIBI:
-                {
-                    GGML_ASSERT((src0t == GGML_TYPE_F32));
-
-                    const int nth = MIN(1024, ne00);
-
-                    //const int n_past = ((int32_t *) dst->op_params)[0];
-                    const int n_head = ((int32_t *) dst->op_params)[1];
-
-                    float max_bias;
-                    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-                    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-                    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-                    [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
-                    [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
-                    [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                } break;
             case GGML_OP_ROPE:
                 {
                     GGML_ASSERT(ne10 == ne02);
@@ -2380,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute(
                 {
                     GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-                    const int sf = dst->op_params[0];
+                    const float sf0 = (float)ne0/src0->ne[0];
+                    const float sf1 = (float)ne1/src0->ne[1];
+                    const float sf2 = (float)ne2/src0->ne[2];
+                    const float sf3 = (float)ne3/src0->ne[3];
 
                     const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 
@@ -2403,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
                     [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
                     [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
-                    [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+                    [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
+                    [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
+                    [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
+                    [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
 
                     const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
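With one scale factor per dimension, the kernel can map every destination coordinate back to a source coordinate independently per axis. A sketch of the nearest-neighbor index mapping assumed by this scheme:

    #include <stdint.h>

    /* map a destination index back to a source index along one axis,
       given sf = (float) ne_dst / ne_src */
    static int64_t src_index(int64_t i_dst, float sf) {
        return (int64_t) (i_dst / sf);
    }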
@@ -2539,13 +2518,14 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
-                    GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ne00 % 4  == 0);
+                    GGML_ASSERT(ne11 % 32 == 0);
+
                     GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-                    struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+                    GGML_ASSERT(ggml_are_same_shape (src1, src2));
 
-                    GGML_ASSERT(ggml_are_same_shape(src1, src2));
-                    GGML_ASSERT(src3);
+                    struct ggml_tensor * src3 = gf->nodes[i]->src[3];
 
                     size_t offs_src3 = 0;
 
@@ -2555,8 +2535,13 @@ static enum ggml_status ggml_metal_graph_compute(
                     GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
                             "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
 
+                    const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+                    const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+                    const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+                    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+
                     const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-                    const int64_t  ne31 = src3 ? src3->ne[1] : 0;
+                    //const int64_t ne31 = src3 ? src3->ne[1] : 0;
                     const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
                     const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
 
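`GGML_PAD(x, n)` rounds `x` up to the next multiple of `n` (this is its definition in ggml.h), so the mask assert above demands at least the query count rounded up to 8 rows. As a quick check:

    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
    /* GGML_PAD(1, 8) == 8, GGML_PAD(8, 8) == 8, GGML_PAD(9, 8) == 16 */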
@@ -2568,7 +2553,16 @@ static enum ggml_status ggml_metal_graph_compute(
                     const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
 
                     float scale;
-                    memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                    float max_bias;
+
+                    memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                    memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+                    const uint32_t n_head      = src0->ne[2];
+                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
                     id<MTLComputePipelineState> pipeline = nil;
 
@@ -2605,34 +2599,38 @@ static enum ggml_status ggml_metal_graph_compute(
                     }
 
                     [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:4];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
-                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
-                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
-                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
-                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
-                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
-                    [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
-                    [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
-                    [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
-                    [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
-                    [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
-                    [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
-                    [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
-                    [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
-                    [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
-                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:23];
-                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:24];
-                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:25];
-                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:26];
-                    [encoder setBytes:&scale length:sizeof( float) atIndex:27];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                    if (id_src3) {
+                        [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                    } else {
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+                    }
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+                    [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
+                    [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
+                    [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
+                    [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
+                    [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
+                    [encoder setBytes:&scale length:sizeof( float) atIndex:23];
+                    [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
+                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
+                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
+                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
 
                     if (!use_vec_kernel) {
                         // half8x8 kernel
@@ -2840,7 +2838,11 @@ GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_metal_free_device();
 
     if (ctx->owned) {
+#if TARGET_OS_OSX
+        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
+#else
         free(ctx->all_data);
+#endif
     }
 
     free(ctx);
@@ -2944,14 +2946,16 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     ctx->owned = true;
     ctx->n_buffers = 1;
 
-    ctx->buffers[0].data = ctx->all_data;
-    ctx->buffers[0].size = size;
-    ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
-                    length:size_aligned
-                    options:MTLResourceStorageModeShared
-                    deallocator:nil];
+    if (ctx->all_data != NULL) {
+        ctx->buffers[0].data = ctx->all_data;
+        ctx->buffers[0].size = size;
+        ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
+                        length:size_aligned
+                        options:MTLResourceStorageModeShared
+                        deallocator:nil];
+    }
 
-    if (ctx->buffers[0].metal == nil) {
+    if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) {
         GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
         ggml_backend_metal_free_device();