llama_cpp 0.15.1 → 0.15.2
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/vendor/tmp/llama.cpp/Makefile +3 -3
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +15 -7
- data/vendor/tmp/llama.cpp/ggml-impl.h +7 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +114 -125
- data/vendor/tmp/llama.cpp/ggml-metal.metal +86 -109
- data/vendor/tmp/llama.cpp/ggml-quants.c +2202 -28
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1032 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +24 -143
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +4 -2
- data/vendor/tmp/llama.cpp/ggml.c +726 -646
- data/vendor/tmp/llama.cpp/ggml.h +28 -17
- data/vendor/tmp/llama.cpp/llama.cpp +478 -281
- data/vendor/tmp/llama.cpp/llama.h +3 -0
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -2169
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -12
- data/vendor/tmp/llama.cpp/unicode.cpp +89 -111
- data/vendor/tmp/llama.cpp/unicode.h +44 -12
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 30dd4c29b86098faf7c78de5fa8e57021b631bb5eb3d14c93f63f1d186383ab8
+  data.tar.gz: b011d891f1cd725f84821428a8db24004b52c9614e785f493f721f7abde71029
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 6c1628f93762747688f802db8593946e8581c869f63c610669b45759f644b3d19b061825b788e328b6b984977112837586ed398b6118a8f8e5f0c7f6fd0eb2dd
+  data.tar.gz: 2f8c3d9f1e6c0f6db7e0682995c8d34179d5405d32784bf00f04a3408cb5bf4c95557bfa1692026f8d3dc9e672d6b15dec5d33cbd76ddc1d94e5ec964a9d0409
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,9 @@
+## [[0.15.2](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.1...v0.15.2)] - 2024-05-18
+
+- Bump llama.cpp from b2839 to b2917.
+
+Implementation binding for rpc_servers in llama_model_params has been skipped.
+
 ## [[0.15.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.0...v0.15.1)] - 2024-05-11
 
 - Bump llama.cpp from b2781 to b2839.
data/lib/llama_cpp/version.rb
CHANGED
@@ -3,8 +3,8 @@
 # llama_cpp.rb provides Ruby bindings for the llama.cpp.
 module LLaMACpp
   # The version of llama_cpp.rb you install.
-  VERSION = '0.15.1'
+  VERSION = '0.15.2'
 
   # The version of llama.cpp bundled with llama_cpp.rb.
-  LLAMA_CPP_VERSION = 'b2839'
+  LLAMA_CPP_VERSION = 'b2917'
 end
data/vendor/tmp/llama.cpp/Makefile
CHANGED
@@ -562,10 +562,10 @@ endif # LLAMA_VULKAN
 ifdef LLAMA_HIPBLAS
 	ifeq ($(wildcard /opt/rocm),)
 		ROCM_PATH ?= /usr
-
+		AMDGPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
 	else
 		ROCM_PATH ?= /opt/rocm
-
+		AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
 	endif
 	HIPCC ?= $(CCACHE) $(ROCM_PATH)/bin/hipcc
 	LLAMA_CUDA_DMMV_X ?= 32
@@ -577,7 +577,7 @@ ifdef LLAMA_HIP_UMA
 endif # LLAMA_HIP_UMA
 	MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
 	MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas
-	HIPFLAGS += $(addprefix --offload-arch=,$(
+	HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
 	HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
 	HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
 	HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
data/vendor/tmp/llama.cpp/ggml-backend.c
CHANGED
@@ -1182,9 +1182,9 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
 static char * fmt_size(size_t size) {
     static char buffer[128];
     if (size >= 1024*1024) {
-
+        snprintf(buffer, sizeof(buffer), "%zuM", size/1024/1024);
     } else {
-
+        snprintf(buffer, sizeof(buffer), "%zuK", size/1024);
     }
     return buffer;
 }
@@ -1895,7 +1895,6 @@ void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * t
 
     tensor->buffer = buffer;
     tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
-    tensor->backend = tensor->view_src->backend;
     ggml_backend_buffer_init_tensor(buffer, tensor);
 }
 
data/vendor/tmp/llama.cpp/ggml-cuda.cu
CHANGED
@@ -4,7 +4,6 @@
 
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
-#include "ggml-cuda/alibi.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argsort.cuh"
 #include "ggml-cuda/binbcast.cuh"
@@ -2205,6 +2204,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_RELU:
             ggml_cuda_op_relu(ctx, dst);
             break;
+        case GGML_UNARY_OP_SIGMOID:
+            ggml_cuda_op_sigmoid(ctx, dst);
+            break;
         case GGML_UNARY_OP_HARDSIGMOID:
             ggml_cuda_op_hardsigmoid(ctx, dst);
             break;
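The hunk above wires the new GGML_UNARY_OP_SIGMOID case into the CUDA dispatch. As a rough illustration of what the operation computes, here is a plain-C sketch of an element-wise sigmoid pass; it is not the actual ggml_cuda_op_sigmoid kernel, which runs on the GPU over ggml tensors:

```c
#include <math.h>
#include <stddef.h>

// Plain-C sketch: y[i] = 1 / (1 + exp(-x[i])) over n elements.
// Illustration only; the real CUDA kernel operates on ggml tensor data.
static void sigmoid_f32(const float * x, float * y, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        y[i] = 1.0f / (1.0f + expf(-x[i]));
    }
}
```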
@@ -2277,9 +2279,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
-        case GGML_OP_ALIBI:
-            ggml_cuda_op_alibi(ctx, dst);
-            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -2559,7 +2558,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
     }
 
     // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-    if (cuda_graph_update_required) {
+    if (use_cuda_graph && cuda_graph_update_required) {
         cuda_ctx->cuda_graph->number_consecutive_updates++;
     } else {
         cuda_ctx->cuda_graph->number_consecutive_updates = 0;
@@ -2714,12 +2713,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 }
 
 GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
@@ -2829,7 +2830,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ROPE:
-        case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
@@ -2841,8 +2841,16 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_FLASH_ATTN_EXT:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
+#else
+            if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
     }
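The last hunk turns FLASH_ATTN_EXT support from a blanket `return true` into a head-size and architecture check. Restated as a standalone predicate (a sketch only; `hip_amd` and `cc_volta_or_newer` stand in for the build-flag and compute-capability checks in the actual code):

```c
#include <stdbool.h>
#include <stdint.h>

// head_size corresponds to op->src[0]->ne[0]; hip_amd stands in for a
// GGML_USE_HIPBLAS/__HIP_PLATFORM_AMD__ build, and cc_volta_or_newer for the
// ggml_cuda_info().devices[...].cc >= CC_VOLTA check.
static bool flash_attn_ext_supported(int64_t head_size, bool hip_amd, bool cc_volta_or_newer) {
    if (head_size == 64 || head_size == 128) {
        return true;                      // these head sizes are accepted on all supported GPUs
    }
    return !hip_amd && cc_volta_or_newer; // other head sizes need an NVIDIA GPU, Volta or newer
}
```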
data/vendor/tmp/llama.cpp/ggml-impl.h
CHANGED
@@ -120,9 +120,16 @@ extern "C" {
 #ifndef __F16C__
 #define __F16C__
 #endif
+#endif
+
+// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
+#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
 #ifndef __SSE3__
 #define __SSE3__
 #endif
+#ifndef __SSSE3__
+#define __SSSE3__
+#endif
 #endif
 
 // 16-bit float
data/vendor/tmp/llama.cpp/ggml-kompute.cpp
CHANGED
@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
             case GGML_OP_SOFT_MAX:
                 {
                     float scale;
-
+                    float max_bias;
 
-
+                    memcpy(&scale,    (float *)dst->op_params + 0, sizeof(float));
+                    memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
 #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
                     GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
-
+
+#pragma message("TODO: add ALiBi support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+                    GGML_ASSERT(max_bias == 0.0f);
 
                     ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                 } break;
data/vendor/tmp/llama.cpp/ggml-metal.m
CHANGED
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CLAMP,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
-    GGML_METAL_KERNEL_TYPE_ALIBI_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -494,6 +494,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@@ -623,7 +624,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -633,14 +633,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256,
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -732,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
@@ -759,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_GROUP_NORM:
             return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
-        case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
             return true;
@@ -772,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_FLASH_ATTN_EXT:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@@ -1194,24 +1195,24 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_CLAMP:
-
-
+                    {
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
 
-
-
-
-
+                        float min;
+                        float max;
+                        memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+                        memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
 
-
-
-
-
-
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&min length:sizeof(min) atIndex:2];
+                        [encoder setBytes:&max length:sizeof(max) atIndex:3];
 
-
+                        const int64_t n = ggml_nelements(dst);
 
-
-
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
                 case GGML_OP_UNARY:
                     switch (ggml_get_unary_op(gf->nodes[i])) {
                         // we are not taking into account the strides, so for now require contiguous tensors
@@ -1239,6 +1240,18 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_UNARY_OP_SIGMOID:
+                    {
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_UNARY_OP_GELU:
@@ -1357,16 +1370,15 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_SOFT_MAX:
                 {
                     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-                    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
 
                     int nth = 32; // SIMD width
 
                     id<MTLComputePipelineState> pipeline = nil;
 
-                    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16)
+                    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
                     if (ne00%4 == 0) {
-                        while (nth < ne00/4 && nth < 256) {
+                        while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
                             nth *= 2;
                         }
                         if (use_f16) {
@@ -1375,7 +1387,7 @@ static enum ggml_status ggml_metal_graph_compute(
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
                         }
                     } else {
-                        while (nth < ne00 && nth <
+                        while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
                             nth *= 2;
                         }
                         if (use_f16) {
@@ -1394,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute(
                     const int64_t nrows_x = ggml_nrows(src0);
                     const int64_t nrows_y = src0->ne[1];
 
-                    const uint32_t
-                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float)
+                    const uint32_t n_head      = nrows_x/nrows_y;
+                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
                     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
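The m0, m1 and n_head_log2 values computed here parameterize the ALiBi bias slopes that the soft_max and flash-attention paths now handle via max_bias (the standalone GGML_OP_ALIBI op is removed elsewhere in this diff). A worked sketch of how a per-head slope is typically derived from these values follows; only the m0/m1/n_head_log2 formulas are taken from the hunk above, the head-to-slope mapping is an assumption based on the usual ALiBi recipe:

```c
#include <math.h>

// m0, m1 and n_head_log2 are computed exactly as in the hunk above; the
// mapping from head index to slope is an assumption about how the kernels
// consume these parameters.
static float alibi_slope(float max_bias, unsigned n_head, unsigned head) {
    const unsigned n_head_log2 = 1u << (unsigned) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    return head < n_head_log2 ? powf(m0, (float) (head + 1))
                              : powf(m1, (float) (2*(head - n_head_log2) + 1));
}
```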
@@ -1407,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute(
                     } else {
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                     }
-
-
-
-
-
-                    [encoder
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
-                    [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
-                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
-                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
-                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                    [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                    [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
                     [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -2225,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
-            case GGML_OP_ALIBI:
-                {
-                    GGML_ASSERT((src0t == GGML_TYPE_F32));
-
-                    const int nth = MIN(1024, ne00);
-
-                    //const int n_past = ((int32_t *) dst->op_params)[0];
-                    const int n_head = ((int32_t *) dst->op_params)[1];
-
-                    float max_bias;
-                    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-                    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-                    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
-
-                    [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-                    [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
-                    [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
-                    [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                } break;
             case GGML_OP_ROPE:
                 {
                     GGML_ASSERT(ne10 == ne02);
@@ -2389,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute(
                 {
                     GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-                    const
+                    const float sf0 = (float)ne0/src0->ne[0];
+                    const float sf1 = (float)ne1/src0->ne[1];
+                    const float sf2 = (float)ne2/src0->ne[2];
+                    const float sf3 = (float)ne3/src0->ne[3];
 
                     const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 
@@ -2412,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
                     [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
                     [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
-                    [encoder setBytes:&
+                    [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
+                    [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
+                    [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
+                    [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
 
                     const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
@@ -2548,13 +2518,14 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
-                    GGML_ASSERT(ne00 % 4
+                    GGML_ASSERT(ne00 % 4  == 0);
+                    GGML_ASSERT(ne11 % 32 == 0);
+
                     GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-
+                    GGML_ASSERT(ggml_are_same_shape (src1, src2));
 
-
-                    GGML_ASSERT(src3);
+                    struct ggml_tensor * src3 = gf->nodes[i]->src[3];
 
                     size_t offs_src3 = 0;
 
@@ -2564,8 +2535,13 @@ static enum ggml_status ggml_metal_graph_compute(
                     GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
                             "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
 
+                    const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+                    const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+                    const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+                    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+
                     const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-
+                    //const int64_t ne31 = src3 ? src3->ne[1] : 0;
                     const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
                     const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
 
@@ -2577,7 +2553,16 @@ static enum ggml_status ggml_metal_graph_compute(
                     const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
 
                     float scale;
-
+                    float max_bias;
+
+                    memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                    memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+                    const uint32_t n_head      = src0->ne[2];
+                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
                     id<MTLComputePipelineState> pipeline = nil;
 
@@ -2614,34 +2599,38 @@ static enum ggml_status ggml_metal_graph_compute(
                     }
 
                     [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0
-                    [encoder setBuffer:id_src1
-                    [encoder setBuffer:id_src2
-
-
-
-
-
-                    [encoder
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&nb11
-                    [encoder setBytes:&nb12
-                    [encoder setBytes:&nb13
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&scale
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                    if (id_src3) {
+                        [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                    } else {
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+                    }
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+                    [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
+                    [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
+                    [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
+                    [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
+                    [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
+                    [encoder setBytes:&scale length:sizeof( float) atIndex:23];
+                    [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
+                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
+                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
+                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
 
                     if (!use_vec_kernel) {
                         // half8x8 kernel