llama_cpp 0.15.1 → 0.15.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/vendor/tmp/llama.cpp/Makefile +3 -3
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +15 -7
- data/vendor/tmp/llama.cpp/ggml-impl.h +7 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +114 -125
- data/vendor/tmp/llama.cpp/ggml-metal.metal +86 -109
- data/vendor/tmp/llama.cpp/ggml-quants.c +2202 -28
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1032 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +24 -143
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +4 -2
- data/vendor/tmp/llama.cpp/ggml.c +726 -646
- data/vendor/tmp/llama.cpp/ggml.h +28 -17
- data/vendor/tmp/llama.cpp/llama.cpp +478 -281
- data/vendor/tmp/llama.cpp/llama.h +3 -0
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -2169
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -12
- data/vendor/tmp/llama.cpp/unicode.cpp +89 -111
- data/vendor/tmp/llama.cpp/unicode.h +44 -12
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 30dd4c29b86098faf7c78de5fa8e57021b631bb5eb3d14c93f63f1d186383ab8
|
4
|
+
data.tar.gz: b011d891f1cd725f84821428a8db24004b52c9614e785f493f721f7abde71029
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 6c1628f93762747688f802db8593946e8581c869f63c610669b45759f644b3d19b061825b788e328b6b984977112837586ed398b6118a8f8e5f0c7f6fd0eb2dd
|
7
|
+
data.tar.gz: 2f8c3d9f1e6c0f6db7e0682995c8d34179d5405d32784bf00f04a3408cb5bf4c95557bfa1692026f8d3dc9e672d6b15dec5d33cbd76ddc1d94e5ec964a9d0409
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,9 @@
|
|
1
|
+
## [[0.15.2](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.1...v0.15.2)] - 2024-05-18
|
2
|
+
|
3
|
+
- Bump llama.cpp from b2839 to b2917.
|
4
|
+
|
5
|
+
Implementation binding for rpc_servers in llama_model_params has been skipped.
|
6
|
+
|
1
7
|
## [[0.15.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.0...v0.15.1)] - 2024-05-11
|
2
8
|
|
3
9
|
- Bump llama.cpp from b2781 to b2839.
|
data/lib/llama_cpp/version.rb
CHANGED
@@ -3,8 +3,8 @@
|
|
3
3
|
# llama_cpp.rb provides Ruby bindings for the llama.cpp.
|
4
4
|
module LLaMACpp
|
5
5
|
# The version of llama_cpp.rb you install.
|
6
|
-
VERSION = '0.15.1'
|
6
|
+
VERSION = '0.15.2'
|
7
7
|
|
8
8
|
# The version of llama.cpp bundled with llama_cpp.rb.
|
9
|
-
LLAMA_CPP_VERSION = 'b2839'
|
9
|
+
LLAMA_CPP_VERSION = 'b2917'
|
10
10
|
end
|
@@ -562,10 +562,10 @@ endif # LLAMA_VULKAN
|
|
562
562
|
ifdef LLAMA_HIPBLAS
|
563
563
|
ifeq ($(wildcard /opt/rocm),)
|
564
564
|
ROCM_PATH ?= /usr
|
565
|
-
|
565
|
+
AMDGPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
|
566
566
|
else
|
567
567
|
ROCM_PATH ?= /opt/rocm
|
568
|
-
|
568
|
+
AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
|
569
569
|
endif
|
570
570
|
HIPCC ?= $(CCACHE) $(ROCM_PATH)/bin/hipcc
|
571
571
|
LLAMA_CUDA_DMMV_X ?= 32
|
@@ -577,7 +577,7 @@ ifdef LLAMA_HIP_UMA
|
|
577
577
|
endif # LLAMA_HIP_UMA
|
578
578
|
MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
|
579
579
|
MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas
|
580
|
-
HIPFLAGS += $(addprefix --offload-arch=,$(
|
580
|
+
HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
|
581
581
|
HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
|
582
582
|
HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
|
583
583
|
HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
|
@@ -1182,9 +1182,9 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
|
|
1182
1182
|
static char * fmt_size(size_t size) {
|
1183
1183
|
static char buffer[128];
|
1184
1184
|
if (size >= 1024*1024) {
|
1185
|
-
|
1185
|
+
snprintf(buffer, sizeof(buffer), "%zuM", size/1024/1024);
|
1186
1186
|
} else {
|
1187
|
-
|
1187
|
+
snprintf(buffer, sizeof(buffer), "%zuK", size/1024);
|
1188
1188
|
}
|
1189
1189
|
return buffer;
|
1190
1190
|
}
|
@@ -1895,7 +1895,6 @@ void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * t
|
|
1895
1895
|
|
1896
1896
|
tensor->buffer = buffer;
|
1897
1897
|
tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
|
1898
|
-
tensor->backend = tensor->view_src->backend;
|
1899
1898
|
ggml_backend_buffer_init_tensor(buffer, tensor);
|
1900
1899
|
}
|
1901
1900
|
|
@@ -4,7 +4,6 @@
|
|
4
4
|
|
5
5
|
#include "ggml-cuda/common.cuh"
|
6
6
|
#include "ggml-cuda/acc.cuh"
|
7
|
-
#include "ggml-cuda/alibi.cuh"
|
8
7
|
#include "ggml-cuda/arange.cuh"
|
9
8
|
#include "ggml-cuda/argsort.cuh"
|
10
9
|
#include "ggml-cuda/binbcast.cuh"
|
@@ -2205,6 +2204,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|
2205
2204
|
case GGML_UNARY_OP_RELU:
|
2206
2205
|
ggml_cuda_op_relu(ctx, dst);
|
2207
2206
|
break;
|
2207
|
+
case GGML_UNARY_OP_SIGMOID:
|
2208
|
+
ggml_cuda_op_sigmoid(ctx, dst);
|
2209
|
+
break;
|
2208
2210
|
case GGML_UNARY_OP_HARDSIGMOID:
|
2209
2211
|
ggml_cuda_op_hardsigmoid(ctx, dst);
|
2210
2212
|
break;
|
@@ -2277,9 +2279,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|
2277
2279
|
case GGML_OP_ROPE:
|
2278
2280
|
ggml_cuda_op_rope(ctx, dst);
|
2279
2281
|
break;
|
2280
|
-
case GGML_OP_ALIBI:
|
2281
|
-
ggml_cuda_op_alibi(ctx, dst);
|
2282
|
-
break;
|
2283
2282
|
case GGML_OP_IM2COL:
|
2284
2283
|
ggml_cuda_op_im2col(ctx, dst);
|
2285
2284
|
break;
|
@@ -2559,7 +2558,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|
2559
2558
|
}
|
2560
2559
|
|
2561
2560
|
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
|
2562
|
-
if (cuda_graph_update_required) {
|
2561
|
+
if (use_cuda_graph && cuda_graph_update_required) {
|
2563
2562
|
cuda_ctx->cuda_graph->number_consecutive_updates++;
|
2564
2563
|
} else {
|
2565
2564
|
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
|
@@ -2714,12 +2713,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|
2714
2713
|
}
|
2715
2714
|
|
2716
2715
|
GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
2716
|
+
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
2717
2717
|
switch (op->op) {
|
2718
2718
|
case GGML_OP_UNARY:
|
2719
2719
|
switch (ggml_get_unary_op(op)) {
|
2720
2720
|
case GGML_UNARY_OP_GELU:
|
2721
2721
|
case GGML_UNARY_OP_SILU:
|
2722
2722
|
case GGML_UNARY_OP_RELU:
|
2723
|
+
case GGML_UNARY_OP_SIGMOID:
|
2723
2724
|
case GGML_UNARY_OP_HARDSIGMOID:
|
2724
2725
|
case GGML_UNARY_OP_HARDSWISH:
|
2725
2726
|
case GGML_UNARY_OP_GELU_QUICK:
|
@@ -2829,7 +2830,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|
2829
2830
|
case GGML_OP_DIAG_MASK_INF:
|
2830
2831
|
case GGML_OP_SOFT_MAX:
|
2831
2832
|
case GGML_OP_ROPE:
|
2832
|
-
case GGML_OP_ALIBI:
|
2833
2833
|
case GGML_OP_IM2COL:
|
2834
2834
|
case GGML_OP_POOL_2D:
|
2835
2835
|
case GGML_OP_SUM_ROWS:
|
@@ -2841,8 +2841,16 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|
2841
2841
|
case GGML_OP_ARANGE:
|
2842
2842
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
2843
2843
|
case GGML_OP_LEAKY_RELU:
|
2844
|
-
case GGML_OP_FLASH_ATTN_EXT:
|
2845
2844
|
return true;
|
2845
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
2846
|
+
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
2847
|
+
return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
|
2848
|
+
#else
|
2849
|
+
if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
|
2850
|
+
return true;
|
2851
|
+
}
|
2852
|
+
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
|
2853
|
+
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
2846
2854
|
default:
|
2847
2855
|
return false;
|
2848
2856
|
}
|
@@ -120,9 +120,16 @@ extern "C" {
|
|
120
120
|
#ifndef __F16C__
|
121
121
|
#define __F16C__
|
122
122
|
#endif
|
123
|
+
#endif
|
124
|
+
|
125
|
+
// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
|
126
|
+
#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
|
123
127
|
#ifndef __SSE3__
|
124
128
|
#define __SSE3__
|
125
129
|
#endif
|
130
|
+
#ifndef __SSSE3__
|
131
|
+
#define __SSSE3__
|
132
|
+
#endif
|
126
133
|
#endif
|
127
134
|
|
128
135
|
// 16-bit float
|
@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
1559
1559
|
case GGML_OP_SOFT_MAX:
|
1560
1560
|
{
|
1561
1561
|
float scale;
|
1562
|
-
|
1562
|
+
float max_bias;
|
1563
1563
|
|
1564
|
-
|
1564
|
+
memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
|
1565
|
+
memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
|
1566
|
+
|
1567
|
+
#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
|
1565
1568
|
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
1566
1569
|
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
|
1567
|
-
|
1570
|
+
|
1571
|
+
#pragma message("TODO: add ALiBi support")
|
1572
|
+
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
|
1573
|
+
GGML_ASSERT(max_bias == 0.0f);
|
1568
1574
|
|
1569
1575
|
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
|
1570
1576
|
} break;
|
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
|
|
40
40
|
GGML_METAL_KERNEL_TYPE_CLAMP,
|
41
41
|
GGML_METAL_KERNEL_TYPE_TANH,
|
42
42
|
GGML_METAL_KERNEL_TYPE_RELU,
|
43
|
+
GGML_METAL_KERNEL_TYPE_SIGMOID,
|
43
44
|
GGML_METAL_KERNEL_TYPE_GELU,
|
44
45
|
GGML_METAL_KERNEL_TYPE_GELU_4,
|
45
46
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
|
|
169
170
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
170
171
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
171
172
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
172
|
-
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
173
173
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
174
174
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
175
175
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
@@ -494,6 +494,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
494
494
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
495
495
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
496
496
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
497
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
497
498
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
498
499
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
499
500
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
@@ -623,7 +624,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
623
624
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
624
625
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
625
626
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
626
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
627
627
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
628
628
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
629
629
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
@@ -633,14 +633,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
633
633
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
634
634
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
635
635
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
636
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64,
|
637
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80,
|
638
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96,
|
639
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112,
|
640
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128,
|
641
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256,
|
642
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128,
|
643
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256,
|
636
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
|
637
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
|
638
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
|
639
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
|
640
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
|
641
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
642
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
643
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
644
644
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
645
645
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
646
646
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
@@ -732,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
732
732
|
switch (ggml_get_unary_op(op)) {
|
733
733
|
case GGML_UNARY_OP_TANH:
|
734
734
|
case GGML_UNARY_OP_RELU:
|
735
|
+
case GGML_UNARY_OP_SIGMOID:
|
735
736
|
case GGML_UNARY_OP_GELU:
|
736
737
|
case GGML_UNARY_OP_GELU_QUICK:
|
737
738
|
case GGML_UNARY_OP_SILU:
|
@@ -759,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
759
760
|
case GGML_OP_GROUP_NORM:
|
760
761
|
return ctx->support_simdgroup_reduction;
|
761
762
|
case GGML_OP_NORM:
|
762
|
-
case GGML_OP_ALIBI:
|
763
763
|
case GGML_OP_ROPE:
|
764
764
|
case GGML_OP_IM2COL:
|
765
765
|
return true;
|
@@ -772,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
772
772
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
773
773
|
case GGML_OP_ARGSORT:
|
774
774
|
case GGML_OP_LEAKY_RELU:
|
775
|
-
case GGML_OP_FLASH_ATTN_EXT:
|
776
775
|
return true;
|
776
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
777
|
+
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
777
778
|
case GGML_OP_MUL_MAT:
|
778
779
|
case GGML_OP_MUL_MAT_ID:
|
779
780
|
return ctx->support_simdgroup_reduction &&
|
@@ -1194,24 +1195,24 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1194
1195
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1195
1196
|
} break;
|
1196
1197
|
case GGML_OP_CLAMP:
|
1197
|
-
|
1198
|
-
|
1198
|
+
{
|
1199
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
|
1199
1200
|
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1201
|
+
float min;
|
1202
|
+
float max;
|
1203
|
+
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
1204
|
+
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
|
1204
1205
|
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1206
|
+
[encoder setComputePipelineState:pipeline];
|
1207
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1208
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1209
|
+
[encoder setBytes:&min length:sizeof(min) atIndex:2];
|
1210
|
+
[encoder setBytes:&max length:sizeof(max) atIndex:3];
|
1210
1211
|
|
1211
|
-
|
1212
|
+
const int64_t n = ggml_nelements(dst);
|
1212
1213
|
|
1213
|
-
|
1214
|
-
|
1214
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1215
|
+
} break;
|
1215
1216
|
case GGML_OP_UNARY:
|
1216
1217
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
1217
1218
|
// we are not taking into account the strides, so for now require contiguous tensors
|
@@ -1239,6 +1240,18 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1239
1240
|
|
1240
1241
|
const int64_t n = ggml_nelements(dst);
|
1241
1242
|
|
1243
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1244
|
+
} break;
|
1245
|
+
case GGML_UNARY_OP_SIGMOID:
|
1246
|
+
{
|
1247
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
|
1248
|
+
|
1249
|
+
[encoder setComputePipelineState:pipeline];
|
1250
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1251
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1252
|
+
|
1253
|
+
const int64_t n = ggml_nelements(dst);
|
1254
|
+
|
1242
1255
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1243
1256
|
} break;
|
1244
1257
|
case GGML_UNARY_OP_GELU:
|
@@ -1357,16 +1370,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1357
1370
|
case GGML_OP_SOFT_MAX:
|
1358
1371
|
{
|
1359
1372
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
1360
|
-
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
|
1361
1373
|
|
1362
1374
|
int nth = 32; // SIMD width
|
1363
1375
|
|
1364
1376
|
id<MTLComputePipelineState> pipeline = nil;
|
1365
1377
|
|
1366
|
-
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16)
|
1378
|
+
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
1367
1379
|
|
1368
1380
|
if (ne00%4 == 0) {
|
1369
|
-
while (nth < ne00/4 && nth < 256) {
|
1381
|
+
while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
|
1370
1382
|
nth *= 2;
|
1371
1383
|
}
|
1372
1384
|
if (use_f16) {
|
@@ -1375,7 +1387,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1375
1387
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
|
1376
1388
|
}
|
1377
1389
|
} else {
|
1378
|
-
while (nth < ne00 && nth <
|
1390
|
+
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
|
1379
1391
|
nth *= 2;
|
1380
1392
|
}
|
1381
1393
|
if (use_f16) {
|
@@ -1394,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1394
1406
|
const int64_t nrows_x = ggml_nrows(src0);
|
1395
1407
|
const int64_t nrows_y = src0->ne[1];
|
1396
1408
|
|
1397
|
-
const uint32_t
|
1398
|
-
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float)
|
1409
|
+
const uint32_t n_head = nrows_x/nrows_y;
|
1410
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
1399
1411
|
|
1400
1412
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
1401
1413
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
@@ -1407,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1407
1419
|
} else {
|
1408
1420
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1409
1421
|
}
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
[encoder
|
1416
|
-
[encoder setBytes:&
|
1417
|
-
[encoder setBytes:&
|
1418
|
-
[encoder setBytes:&
|
1419
|
-
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
|
1420
|
-
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
|
1421
|
-
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
|
1422
|
-
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
|
1423
|
-
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
|
1422
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1423
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1424
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1425
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1426
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
1427
|
+
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
|
1428
|
+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
|
1429
|
+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
|
1430
|
+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
|
1424
1431
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1425
1432
|
|
1426
1433
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
@@ -2225,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2225
2232
|
|
2226
2233
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2227
2234
|
} break;
|
2228
|
-
case GGML_OP_ALIBI:
|
2229
|
-
{
|
2230
|
-
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
2231
|
-
|
2232
|
-
const int nth = MIN(1024, ne00);
|
2233
|
-
|
2234
|
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
2235
|
-
const int n_head = ((int32_t *) dst->op_params)[1];
|
2236
|
-
|
2237
|
-
float max_bias;
|
2238
|
-
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
2239
|
-
|
2240
|
-
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
2241
|
-
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
2242
|
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
2243
|
-
|
2244
|
-
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
|
2245
|
-
|
2246
|
-
[encoder setComputePipelineState:pipeline];
|
2247
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2248
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2249
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2250
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
2251
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
2252
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
2253
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
2254
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
2255
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
2256
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
2257
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
2258
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
2259
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
2260
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
2261
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
2262
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
2263
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
2264
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
2265
|
-
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
2266
|
-
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
2267
|
-
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
2268
|
-
|
2269
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2270
|
-
} break;
|
2271
2235
|
case GGML_OP_ROPE:
|
2272
2236
|
{
|
2273
2237
|
GGML_ASSERT(ne10 == ne02);
|
@@ -2389,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2389
2353
|
{
|
2390
2354
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2391
2355
|
|
2392
|
-
const
|
2356
|
+
const float sf0 = (float)ne0/src0->ne[0];
|
2357
|
+
const float sf1 = (float)ne1/src0->ne[1];
|
2358
|
+
const float sf2 = (float)ne2/src0->ne[2];
|
2359
|
+
const float sf3 = (float)ne3/src0->ne[3];
|
2393
2360
|
|
2394
2361
|
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
2395
2362
|
|
@@ -2412,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2412
2379
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2413
2380
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2414
2381
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2415
|
-
[encoder setBytes:&
|
2382
|
+
[encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
|
2383
|
+
[encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
|
2384
|
+
[encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
|
2385
|
+
[encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
|
2416
2386
|
|
2417
2387
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
2418
2388
|
|
@@ -2548,13 +2518,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2548
2518
|
} break;
|
2549
2519
|
case GGML_OP_FLASH_ATTN_EXT:
|
2550
2520
|
{
|
2551
|
-
GGML_ASSERT(ne00 % 4
|
2521
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
2522
|
+
GGML_ASSERT(ne11 % 32 == 0);
|
2523
|
+
|
2552
2524
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2553
2525
|
|
2554
|
-
|
2526
|
+
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
2555
2527
|
|
2556
|
-
|
2557
|
-
GGML_ASSERT(src3);
|
2528
|
+
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
2558
2529
|
|
2559
2530
|
size_t offs_src3 = 0;
|
2560
2531
|
|
@@ -2564,8 +2535,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2564
2535
|
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
2565
2536
|
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
2566
2537
|
|
2538
|
+
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
2539
|
+
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
2540
|
+
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
2541
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
2542
|
+
|
2567
2543
|
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
2568
|
-
|
2544
|
+
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
2569
2545
|
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
2570
2546
|
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
|
2571
2547
|
|
@@ -2577,7 +2553,16 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2577
2553
|
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
2578
2554
|
|
2579
2555
|
float scale;
|
2580
|
-
|
2556
|
+
float max_bias;
|
2557
|
+
|
2558
|
+
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
2559
|
+
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
2560
|
+
|
2561
|
+
const uint32_t n_head = src0->ne[2];
|
2562
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
2563
|
+
|
2564
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
2565
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
2581
2566
|
|
2582
2567
|
id<MTLComputePipelineState> pipeline = nil;
|
2583
2568
|
|
@@ -2614,34 +2599,38 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2614
2599
|
}
|
2615
2600
|
|
2616
2601
|
[encoder setComputePipelineState:pipeline];
|
2617
|
-
[encoder setBuffer:id_src0
|
2618
|
-
[encoder setBuffer:id_src1
|
2619
|
-
[encoder setBuffer:id_src2
|
2620
|
-
|
2621
|
-
|
2622
|
-
|
2623
|
-
|
2624
|
-
|
2625
|
-
[encoder
|
2626
|
-
[encoder setBytes:&
|
2627
|
-
[encoder setBytes:&
|
2628
|
-
[encoder setBytes:&
|
2629
|
-
[encoder setBytes:&
|
2630
|
-
[encoder setBytes:&
|
2631
|
-
[encoder setBytes:&
|
2632
|
-
[encoder setBytes:&
|
2633
|
-
[encoder setBytes:&
|
2634
|
-
[encoder setBytes:&
|
2635
|
-
[encoder setBytes:&nb11
|
2636
|
-
[encoder setBytes:&nb12
|
2637
|
-
[encoder setBytes:&nb13
|
2638
|
-
[encoder setBytes:&
|
2639
|
-
[encoder setBytes:&
|
2640
|
-
[encoder setBytes:&
|
2641
|
-
[encoder setBytes:&
|
2642
|
-
[encoder setBytes:&
|
2643
|
-
[encoder setBytes:&
|
2644
|
-
[encoder setBytes:&scale
|
2602
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2603
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2604
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2605
|
+
if (id_src3) {
|
2606
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
2607
|
+
} else {
|
2608
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
2609
|
+
}
|
2610
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
2611
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
2612
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
2613
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
2614
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
2615
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
2616
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
2617
|
+
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
2618
|
+
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
2619
|
+
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
2620
|
+
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
2621
|
+
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
2622
|
+
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
2623
|
+
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
2624
|
+
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
2625
|
+
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
2626
|
+
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
2627
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
2628
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
2629
|
+
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
2630
|
+
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
2631
|
+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
2632
|
+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
2633
|
+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
2645
2634
|
|
2646
2635
|
if (!use_vec_kernel) {
|
2647
2636
|
// half8x8 kernel
|