llama_cpp 0.15.0 → 0.15.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/llama_cpp.cpp +6 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +6 -7
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +303 -23
- data/vendor/tmp/llama.cpp/ggml-impl.h +84 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +137 -133
- data/vendor/tmp/llama.cpp/ggml-metal.metal +87 -110
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +1 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +2220 -28
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1032 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +35 -152
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +46843 -39205
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +953 -268
- data/vendor/tmp/llama.cpp/ggml.c +1762 -681
- data/vendor/tmp/llama.cpp/ggml.h +43 -24
- data/vendor/tmp/llama.cpp/llama.cpp +533 -296
- data/vendor/tmp/llama.cpp/llama.h +10 -1
- data/vendor/tmp/llama.cpp/sgemm.cpp +56 -21
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -1637
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -11
- data/vendor/tmp/llama.cpp/unicode.cpp +286 -176
- data/vendor/tmp/llama.cpp/unicode.h +44 -10
- metadata +4 -2
data/vendor/tmp/llama.cpp/ggml-impl.h

@@ -17,6 +17,83 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+/**
+ * Converts brain16 to float32.
+ *
+ * The bfloat16 floating point format has the following structure:
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───┐
+ *     0b0000000000000000 brain16
+ *
+ * Since bf16 has the same number of exponent bits as a 32bit float,
+ * encoding and decoding numbers becomes relatively straightforward.
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───────────────────┐
+ *     0b00000000000000000000000000000000 IEEE binary32
+ *
+ * For comparison, the standard fp16 format has fewer exponent bits.
+ *
+ *       ┌sign
+ *       │
+ *       │  ┌exponent
+ *       │  │
+ *       │  │    ┌mantissa
+ *       │  │    │
+ *       │┌─┴─┐┌─┴──────┐
+ *     0b0000000000000000 IEEE binary16
+ *
+ * @see IEEE 754-2008
+ */
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
+/**
+ * Converts float32 to brain16.
+ *
+ * This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
+ * Subnormals shall be flushed to zero, and NANs will be quiet.
+ * This code should vectorize nicely if using modern compilers.
+ */
+static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
+    ggml_bf16_t h;
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.f = s;
+    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
+        h.bits = (u.i >> 16) | 64; /* force to quiet */
+        return h;
+    }
+    if (!(u.i & 0x7f800000)) { /* subnormal */
+        h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
+        return h;
+    }
+    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
+    return h;
+}
+
+#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
+#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
+
 #ifdef __cplusplus
 extern "C" {
 #endif
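The helpers added above treat bf16 as the upper 16 bits of an IEEE binary32 value: widening is a plain shift, narrowing applies round-to-nearest-even, with NaNs forced quiet and subnormals flushed to signed zero. A minimal standalone sketch of that round trip, using a plain uint16_t as an illustrative stand-in for ggml_bf16_t:

    #include <stdint.h>
    #include <stdio.h>

    typedef uint16_t bf16_bits;   /* illustrative stand-in for ggml_bf16_t */

    static float bf16_to_f32(bf16_bits b) {
        union { float f; uint32_t i; } u;
        u.i = (uint32_t) b << 16;              /* bf16 is the high half of binary32 */
        return u.f;
    }

    static bf16_bits f32_to_bf16(float s) {
        union { float f; uint32_t i; } u;
        u.f = s;
        if ((u.i & 0x7fffffff) > 0x7f800000) { /* NaN: force quiet */
            return (bf16_bits) ((u.i >> 16) | 64);
        }
        if (!(u.i & 0x7f800000)) {             /* subnormal: flush to signed zero */
            return (bf16_bits) ((u.i & 0x80000000) >> 16);
        }
        /* round to nearest, ties to even, then keep the high 16 bits */
        return (bf16_bits) ((u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16);
    }

    int main(void) {
        float x = 3.14159265f;
        printf("%f -> %f\n", x, bf16_to_f32(f32_to_bf16(x))); /* ~2-3 decimal digits survive */
        return 0;
    }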
@@ -43,9 +120,16 @@ extern "C" {
 #ifndef __F16C__
 #define __F16C__
 #endif
+#endif
+
+// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
+#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
 #ifndef __SSE3__
 #define __SSE3__
 #endif
+#ifndef __SSSE3__
+#define __SSSE3__
+#endif
 #endif
 
 // 16-bit float
data/vendor/tmp/llama.cpp/ggml-kompute.cpp

@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         case GGML_OP_SOFT_MAX:
             {
                 float scale;
-
+                float max_bias;
 
-
+                memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
+                memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
 #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
                 GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
-
+
+#pragma message("TODO: add ALiBi support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+                GGML_ASSERT(max_bias == 0.0f);
 
                 ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
             } break;
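Both the Kompute and Metal soft-max paths now read a second float, max_bias, out of the op's parameter block. op_params is an int32 scratch array on the tensor, so float parameters are recovered with memcpy rather than a pointer dereference; a small sketch of the pattern, with a simplified stand-in for the real ggml_tensor layout:

    #include <stdint.h>
    #include <string.h>

    struct toy_tensor {                 /* simplified stand-in for struct ggml_tensor */
        int32_t op_params[16];
    };

    static void read_soft_max_params(const struct toy_tensor * dst,
                                     float * scale, float * max_bias) {
        /* floats live bit-for-bit in the int32 slots; memcpy sidesteps the
           aliasing problems a direct cast-and-dereference could introduce */
        memcpy(scale,    (const float *) dst->op_params + 0, sizeof(float));
        memcpy(max_bias, (const float *) dst->op_params + 1, sizeof(float));
    }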
data/vendor/tmp/llama.cpp/ggml-metal.m

@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CLAMP,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,

@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
-    GGML_METAL_KERNEL_TYPE_ALIBI_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,

@@ -265,11 +265,20 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
 
 static void * ggml_metal_host_malloc(size_t n) {
     void * data = NULL;
+
+#if TARGET_OS_OSX
+    kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
+    if (err != KERN_SUCCESS) {
+        GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
+        return NULL;
+    }
+#else
     const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
     if (result != 0) {
         GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
         return NULL;
    }
+#endif
 
     return data;
 }
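On macOS the host buffer now comes from vm_allocate, paired with the vm_deallocate call added to the buffer-free path later in this diff; other targets keep the posix_memalign path. Either way the memory is page-aligned, which is what allows the allocation further down to be wrapped with newBufferWithBytesNoCopy. A hedged sketch of the non-Apple fallback in isolation (names here are illustrative, not ggml's):

    #include <stdlib.h>
    #include <unistd.h>

    /* allocate n bytes aligned to the system page size; release with free() */
    static void * page_aligned_alloc(size_t n) {
        void * data = NULL;
        long page = sysconf(_SC_PAGESIZE);
        if (page <= 0) {
            page = 4096;                    /* conservative fallback */
        }
        if (posix_memalign(&data, (size_t) page, n) != 0) {
            return NULL;
        }
        return data;
    }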
@@ -485,6 +494,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);

@@ -614,7 +624,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);

@@ -624,14 +633,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256,
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);

@@ -723,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         switch (ggml_get_unary_op(op)) {
             case GGML_UNARY_OP_TANH:
             case GGML_UNARY_OP_RELU:
+            case GGML_UNARY_OP_SIGMOID:
             case GGML_UNARY_OP_GELU:
             case GGML_UNARY_OP_GELU_QUICK:
             case GGML_UNARY_OP_SILU:

@@ -750,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_GROUP_NORM:
             return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
-        case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
             return true;

@@ -763,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_FLASH_ATTN_EXT:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&

@@ -803,7 +813,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
-                return op->ne[3] == 1;
+                return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
             }
         default:
             return false;

@@ -1185,24 +1195,24 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
            case GGML_OP_CLAMP:
-
-
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
 
-
-
-
-
+                    float min;
+                    float max;
+                    memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+                    memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
 
-
-
-
-
-
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBytes:&min length:sizeof(min) atIndex:2];
+                    [encoder setBytes:&max length:sizeof(max) atIndex:3];
 
-
+                    const int64_t n = ggml_nelements(dst);
 
-
-
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
            case GGML_OP_UNARY:
                switch (ggml_get_unary_op(gf->nodes[i])) {
                    // we are not taking into account the strides, so for now require contiguous tensors

@@ -1230,6 +1240,18 @@ static enum ggml_status ggml_metal_graph_compute(
 
                        const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_UNARY_OP_SIGMOID:
+                    {
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    } break;
                case GGML_UNARY_OP_GELU:

@@ -1348,16 +1370,15 @@ static enum ggml_status ggml_metal_graph_compute(
            case GGML_OP_SOFT_MAX:
                {
                    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-                    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
 
                    int nth = 32; // SIMD width
 
                    id<MTLComputePipelineState> pipeline = nil;
 
-                    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16)
+                    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
                    if (ne00%4 == 0) {
-                        while (nth < ne00/4 && nth < 256) {
+                        while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
                            nth *= 2;
                        }
                        if (use_f16) {

@@ -1366,7 +1387,7 @@ static enum ggml_status ggml_metal_graph_compute(
                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
                        }
                    } else {
-                        while (nth < ne00 && nth <
+                        while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
                            nth *= 2;
                        }
                        if (use_f16) {

@@ -1385,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute(
                    const int64_t nrows_x = ggml_nrows(src0);
                    const int64_t nrows_y = src0->ne[1];
 
-                    const uint32_t
-                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float)
+                    const uint32_t n_head = nrows_x/nrows_y;
+                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
                    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
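These m0/m1 factors are how ALiBi is folded directly into the soft-max (and, further down, flash-attention) kernels, replacing the dedicated GGML_OP_ALIBI path that this diff removes: each of the n_head heads gets a slope derived from max_bias, and that slope scales the additive bias applied to the attention logits. A hedged reconstruction of the per-head slope as a standalone helper (the function name is illustrative; the Metal kernels compute the equivalent on the GPU from m0, m1 and n_head_log2):

    #include <math.h>

    /* illustrative helper: ALiBi slope for head h out of n_head heads */
    static float alibi_slope(float max_bias, unsigned n_head, unsigned h) {
        const unsigned n_head_log2 = 1u << (unsigned) floorf(log2f((float) n_head));

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        /* the first n_head_log2 heads use powers of m0, the rest odd powers of m1 */
        return h < n_head_log2 ? powf(m0, (float) (h + 1))
                               : powf(m1, (float) (2*(h - n_head_log2) + 1));
    }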
@@ -1398,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute(
                    } else {
                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                    }
-
-
-
-
-
-                    [encoder
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
-                    [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
-                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
-                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
-                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                    [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                    [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
                    [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                    [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];

@@ -2216,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute(
 
                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                } break;
-            case GGML_OP_ALIBI:
-                {
-                    GGML_ASSERT((src0t == GGML_TYPE_F32));
-
-                    const int nth = MIN(1024, ne00);
-
-                    //const int n_past = ((int32_t *) dst->op_params)[0];
-                    const int n_head = ((int32_t *) dst->op_params)[1];
-
-                    float max_bias;
-                    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-                    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-                    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-                    [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
-                    [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
-                    [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                } break;
            case GGML_OP_ROPE:
                {
                    GGML_ASSERT(ne10 == ne02);

@@ -2380,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute(
                {
                    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-                    const
+                    const float sf0 = (float)ne0/src0->ne[0];
+                    const float sf1 = (float)ne1/src0->ne[1];
+                    const float sf2 = (float)ne2/src0->ne[2];
+                    const float sf3 = (float)ne3/src0->ne[3];
 
                    const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 

@@ -2403,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute(
                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
                    [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
                    [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
-                    [encoder setBytes:&
+                    [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
+                    [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
+                    [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
+                    [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
 
                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 

@@ -2539,13 +2518,14 @@ static enum ggml_status ggml_metal_graph_compute(
                } break;
            case GGML_OP_FLASH_ATTN_EXT:
                {
-                    GGML_ASSERT(ne00 % 4
+                    GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ne11 % 32 == 0);
+
                    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-
+                    GGML_ASSERT(ggml_are_same_shape (src1, src2));
 
-
-                    GGML_ASSERT(src3);
+                    struct ggml_tensor * src3 = gf->nodes[i]->src[3];
 
                    size_t offs_src3 = 0;
 

@@ -2555,8 +2535,13 @@ static enum ggml_status ggml_metal_graph_compute(
                    GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
                            "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
 
+                    const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+                    const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+                    const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+                    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+
                    const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-
+                    //const int64_t ne31 = src3 ? src3->ne[1] : 0;
                    const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
                    const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
 

@@ -2568,7 +2553,16 @@ static enum ggml_status ggml_metal_graph_compute(
                    const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
 
                    float scale;
-
+                    float max_bias;
+
+                    memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                    memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+                    const uint32_t n_head = src0->ne[2];
+                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
                    id<MTLComputePipelineState> pipeline = nil;
 

@@ -2605,34 +2599,38 @@ static enum ggml_status ggml_metal_graph_compute(
                    }
 
                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0
-                    [encoder setBuffer:id_src1
-                    [encoder setBuffer:id_src2
-
-
-
-
-
-                    [encoder
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&nb11
-                    [encoder setBytes:&nb12
-                    [encoder setBytes:&nb13
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&scale
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                    if (id_src3) {
+                        [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                    } else {
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+                    }
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+                    [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
+                    [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
+                    [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
+                    [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
+                    [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
+                    [encoder setBytes:&scale length:sizeof( float) atIndex:23];
+                    [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
+                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
+                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
+                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
 
                    if (!use_vec_kernel) {
                        // half8x8 kernel

@@ -2840,7 +2838,11 @@ GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_
    ggml_backend_metal_free_device();
 
    if (ctx->owned) {
+#if TARGET_OS_OSX
+        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
+#else
        free(ctx->all_data);
+#endif
    }
 
    free(ctx);

@@ -2944,14 +2946,16 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
    ctx->owned = true;
    ctx->n_buffers = 1;
 
-    ctx->
-
-
-
-
-
+    if (ctx->all_data != NULL) {
+        ctx->buffers[0].data  = ctx->all_data;
+        ctx->buffers[0].size  = size;
+        ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
+                        length:size_aligned
+                        options:MTLResourceStorageModeShared
+                        deallocator:nil];
+    }
 
-    if (ctx->buffers[0].metal == nil) {
+    if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) {
        GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
        free(ctx);
        ggml_backend_metal_free_device();