@fugood/llama.node 0.3.3 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +5 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +15 -5
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +1 -1
- package/src/LlamaContext.cpp +81 -18
- package/src/LlamaContext.h +2 -0
- package/src/llama.cpp/.github/workflows/build.yml +197 -159
- package/src/llama.cpp/.github/workflows/docker.yml +5 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +11 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -2
- package/src/llama.cpp/common/arg.cpp +426 -245
- package/src/llama.cpp/common/common.cpp +143 -80
- package/src/llama.cpp/common/common.h +81 -24
- package/src/llama.cpp/common/sampling.cpp +53 -19
- package/src/llama.cpp/common/sampling.h +22 -1
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +101 -148
- package/src/llama.cpp/examples/CMakeLists.txt +32 -13
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +5 -4
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +262 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/llava.cpp +46 -19
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
- package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +9 -5
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
- package/src/llama.cpp/examples/server/server.cpp +1758 -886
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +94 -304
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +4 -0
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
- package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +106 -24
- package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
- package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
- package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
- package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
- package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
- package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
- package/src/llama.cpp/ggml/src/ggml.c +367 -207
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +26 -19
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/src/CMakeLists.txt +2 -7
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +35 -90
- package/src/llama.cpp/src/llama-vocab.cpp +6 -1
- package/src/llama.cpp/src/llama.cpp +1748 -640
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -37
- package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
- package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
- package/src/llama.cpp/tests/test-rope.cpp +61 -20
- package/src/llama.cpp/tests/test-sampling.cpp +2 -2
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
|
@@ -28,8 +28,10 @@
|
|
|
28
28
|
#include "shaderop_getrows_q4_0.h"
|
|
29
29
|
#include "shaderop_getrows_q4_1.h"
|
|
30
30
|
#include "shaderop_getrows_q6_k.h"
|
|
31
|
-
#include "
|
|
32
|
-
#include "
|
|
31
|
+
#include "shaderop_rope_norm_f16.h"
|
|
32
|
+
#include "shaderop_rope_norm_f32.h"
|
|
33
|
+
#include "shaderop_rope_neox_f16.h"
|
|
34
|
+
#include "shaderop_rope_neox_f32.h"
|
|
33
35
|
#include "shaderop_cpy_f16_f16.h"
|
|
34
36
|
#include "shaderop_cpy_f16_f32.h"
|
|
35
37
|
#include "shaderop_cpy_f32_f16.h"
|
|
@@ -345,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t
|
|
|
345
347
|
std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
|
|
346
348
|
vk::DescriptorPoolSize(
|
|
347
349
|
vk::DescriptorType::eStorageBuffer,
|
|
348
|
-
|
|
350
|
+
4 * size // Descriptor count is number of possible tensors to pass into an algorithm
|
|
349
351
|
)
|
|
350
352
|
};
|
|
351
353
|
|
|
@@ -788,7 +790,8 @@ static void ggml_vk_soft_max(
|
|
|
788
790
|
const std::shared_ptr<kp::Tensor>& out,
|
|
789
791
|
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
|
790
792
|
int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
|
|
791
|
-
float scale
|
|
793
|
+
float scale, float max_bias, float m0, float m1,
|
|
794
|
+
uint32_t n_head_log2
|
|
792
795
|
) {
|
|
793
796
|
const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
|
|
794
797
|
kp::shader_data::op_softmax_comp_spv_len);
|
|
@@ -796,12 +799,14 @@ static void ggml_vk_soft_max(
|
|
|
796
799
|
struct PushConstants {
|
|
797
800
|
uint32_t inAOff, inBOff, outOff;
|
|
798
801
|
int32_t ne00, ne01, ne02;
|
|
799
|
-
float scale;
|
|
802
|
+
float scale, max_bias, m0, m1;
|
|
803
|
+
uint32_t n_head_log2;
|
|
800
804
|
int32_t mask;
|
|
801
805
|
} pushConsts {
|
|
802
806
|
safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
|
|
803
807
|
ne00, ne01, ne02,
|
|
804
|
-
scale,
|
|
808
|
+
scale, max_bias, m0, m1,
|
|
809
|
+
n_head_log2,
|
|
805
810
|
bool(inB)
|
|
806
811
|
};
|
|
807
812
|
|
|
@@ -911,9 +916,9 @@ static void ggml_vk_mul_mat_f16(
|
|
|
911
916
|
const std::shared_ptr<kp::Tensor>& out,
|
|
912
917
|
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
|
913
918
|
int32_t ne00, int32_t ne01, int32_t ne02,
|
|
914
|
-
uint32_t nb00, uint32_t nb01, uint32_t nb02,
|
|
919
|
+
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
|
915
920
|
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
|
|
916
|
-
uint32_t nb10, uint32_t nb11, uint32_t nb12,
|
|
921
|
+
uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13,
|
|
917
922
|
int32_t ne0, int32_t ne1,
|
|
918
923
|
uint32_t r2, uint32_t r3
|
|
919
924
|
) {
|
|
@@ -923,17 +928,17 @@ static void ggml_vk_mul_mat_f16(
|
|
|
923
928
|
struct PushConstants {
|
|
924
929
|
uint32_t inAOff, inBOff, outOff;
|
|
925
930
|
int32_t ne00, ne01, ne02;
|
|
926
|
-
uint32_t nb00, nb01, nb02;
|
|
931
|
+
uint32_t nb00, nb01, nb02, nb03;
|
|
927
932
|
int32_t ne10, ne11, ne12;
|
|
928
|
-
uint32_t nb10, nb11, nb12;
|
|
933
|
+
uint32_t nb10, nb11, nb12, nb13;
|
|
929
934
|
int32_t ne0, ne1;
|
|
930
935
|
uint32_t r2, r3;
|
|
931
936
|
} pushConsts {
|
|
932
937
|
safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
|
|
933
938
|
ne00, ne01, ne02,
|
|
934
|
-
nb00, nb01, nb02,
|
|
939
|
+
nb00, nb01, nb02, nb03,
|
|
935
940
|
ne10, ne11, ne12,
|
|
936
|
-
nb10, nb11, nb12,
|
|
941
|
+
nb10, nb11, nb12, nb13,
|
|
937
942
|
ne0, ne1,
|
|
938
943
|
r2, r3
|
|
939
944
|
};
|
|
@@ -1013,6 +1018,8 @@ static void ggml_vk_mul_mat_impl(
|
|
|
1013
1018
|
int32_t ne00, int32_t ne01, int32_t ne02,
|
|
1014
1019
|
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
|
|
1015
1020
|
int32_t ne0, int32_t ne1,
|
|
1021
|
+
uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
|
1022
|
+
uint32_t nb11, uint32_t nb12, uint32_t nb13,
|
|
1016
1023
|
uint32_t r2, uint32_t r3
|
|
1017
1024
|
) {
|
|
1018
1025
|
struct PushConstants {
|
|
@@ -1020,19 +1027,23 @@ static void ggml_vk_mul_mat_impl(
|
|
|
1020
1027
|
int32_t ne00, ne01, ne02;
|
|
1021
1028
|
int32_t ne10, ne12;
|
|
1022
1029
|
int32_t ne0, ne1;
|
|
1030
|
+
uint32_t nb01, nb02, nb03;
|
|
1031
|
+
uint32_t nb11, nb12, nb13;
|
|
1023
1032
|
uint32_t r2, r3;
|
|
1024
1033
|
} pushConsts {
|
|
1025
1034
|
safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
|
|
1026
1035
|
ne00, ne01, ne02,
|
|
1027
1036
|
ne10, ne12,
|
|
1028
1037
|
ne0, ne1,
|
|
1038
|
+
nb01, nb02, nb03,
|
|
1039
|
+
nb11, nb12, nb13,
|
|
1029
1040
|
r2, r3
|
|
1030
1041
|
};
|
|
1031
1042
|
|
|
1032
1043
|
auto name = std::string(__func__) + "_" + suffix;
|
|
1033
1044
|
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
|
1034
1045
|
if (!komputeManager()->hasAlgorithm(name)) {
|
|
1035
|
-
const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
|
|
1046
|
+
const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8;
|
|
1036
1047
|
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
|
|
1037
1048
|
} else {
|
|
1038
1049
|
s_algo = komputeManager()->getAlgorithm(name);
|
|
@@ -1074,19 +1085,26 @@ static void ggml_vk_mul_mat_q4_k(
|
|
|
1074
1085
|
const std::shared_ptr<kp::Tensor>& inB,
|
|
1075
1086
|
const std::shared_ptr<kp::Tensor>& out,
|
|
1076
1087
|
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
|
1077
|
-
int32_t ne00, int32_t ne01, int32_t ne02,
|
|
1078
|
-
int32_t
|
|
1079
|
-
int32_t
|
|
1088
|
+
int32_t ne00, int32_t ne01, int32_t ne02,
|
|
1089
|
+
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
|
|
1090
|
+
int32_t ne0, int32_t ne1,
|
|
1091
|
+
uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
|
1092
|
+
uint32_t nb11, uint32_t nb12, uint32_t nb13,
|
|
1093
|
+
uint32_t r2, uint32_t r3
|
|
1080
1094
|
) {
|
|
1081
1095
|
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
|
|
1082
1096
|
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
|
|
1083
1097
|
|
|
1084
1098
|
struct PushConstants {
|
|
1085
1099
|
uint32_t inAOff, inBOff, outOff;
|
|
1086
|
-
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12
|
|
1100
|
+
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
|
|
1101
|
+
uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
|
|
1102
|
+
uint32_t r2, r3;
|
|
1087
1103
|
} pushConsts {
|
|
1088
|
-
|
|
1089
|
-
ne00, ne10, ne0, ne1, ne01, ne02, ne12,
|
|
1104
|
+
inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
|
|
1105
|
+
ne00, ne10, ne0, ne1, ne01, ne02, ne12,
|
|
1106
|
+
nb01, nb02, nb03, nb11, nb12, nb13,
|
|
1107
|
+
r2, r3
|
|
1090
1108
|
};
|
|
1091
1109
|
|
|
1092
1110
|
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
|
@@ -1108,28 +1126,37 @@ static void ggml_vk_mul_mat_q6_k(
|
|
|
1108
1126
|
const std::shared_ptr<kp::Tensor>& inB,
|
|
1109
1127
|
const std::shared_ptr<kp::Tensor>& out,
|
|
1110
1128
|
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
|
1111
|
-
int32_t ne00, int32_t
|
|
1112
|
-
int32_t
|
|
1129
|
+
int32_t ne00, int32_t ne01, int32_t ne02,
|
|
1130
|
+
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
|
|
1131
|
+
int32_t ne0, int32_t ne1,
|
|
1132
|
+
uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
|
1133
|
+
uint32_t nb11, uint32_t nb12, uint32_t nb13,
|
|
1134
|
+
uint32_t r2, uint32_t r3
|
|
1113
1135
|
) {
|
|
1114
1136
|
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
|
|
1115
1137
|
kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
|
|
1116
1138
|
|
|
1117
1139
|
struct PushConstants {
|
|
1118
1140
|
uint32_t inAOff, inBOff, outOff;
|
|
1119
|
-
int32_t ne00, ne10, ne0, ne1, ne01,
|
|
1141
|
+
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
|
|
1142
|
+
uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
|
|
1143
|
+
uint32_t r2, r3;
|
|
1120
1144
|
} pushConsts {
|
|
1121
1145
|
inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
|
|
1122
|
-
ne00, ne10, ne0, ne1, ne01, ne12
|
|
1146
|
+
ne00, ne10, ne0, ne1, ne01, ne02, ne12,
|
|
1147
|
+
nb01, nb02, nb03, nb11, nb12, nb13,
|
|
1148
|
+
r2, r3
|
|
1123
1149
|
};
|
|
1124
1150
|
|
|
1125
1151
|
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
|
1126
1152
|
if (!komputeManager()->hasAlgorithm(__func__)) {
|
|
1127
|
-
const uint32_t local_x =
|
|
1128
|
-
|
|
1153
|
+
const uint32_t local_x = 2;
|
|
1154
|
+
const uint32_t local_y = ggml_vk_current_device().subgroupSize;
|
|
1155
|
+
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts});
|
|
1129
1156
|
} else {
|
|
1130
1157
|
s_algo = komputeManager()->getAlgorithm(__func__);
|
|
1131
1158
|
s_algo->setTensors({inA, inB, out});
|
|
1132
|
-
s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
|
|
1159
|
+
s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)});
|
|
1133
1160
|
s_algo->setPushConstants<PushConstants>({pushConsts});
|
|
1134
1161
|
s_algo->updateDescriptors(s_kompute_context->pool.get());
|
|
1135
1162
|
}
|
|
@@ -1217,10 +1244,11 @@ static void ggml_vk_rope(
|
|
|
1217
1244
|
kp::Sequence& seq,
|
|
1218
1245
|
const std::shared_ptr<kp::Tensor>& inA,
|
|
1219
1246
|
const std::shared_ptr<kp::Tensor>& inB,
|
|
1247
|
+
const std::shared_ptr<kp::Tensor>& inC,
|
|
1220
1248
|
const std::shared_ptr<kp::Tensor>& out,
|
|
1221
|
-
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
|
1249
|
+
uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
|
|
1222
1250
|
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
|
|
1223
|
-
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
|
|
1251
|
+
float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
|
|
1224
1252
|
int32_t ne01, int32_t ne02, int32_t ne03,
|
|
1225
1253
|
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
|
1226
1254
|
int32_t ne0,
|
|
@@ -1228,11 +1256,17 @@ static void ggml_vk_rope(
|
|
|
1228
1256
|
) {
|
|
1229
1257
|
GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
|
|
1230
1258
|
|
|
1231
|
-
static const auto
|
|
1232
|
-
kp::shader_data::
|
|
1259
|
+
static const auto spirv_norm_f16 = getSpirvShader(
|
|
1260
|
+
kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
|
|
1261
|
+
);
|
|
1262
|
+
static const auto spirv_norm_f32 = getSpirvShader(
|
|
1263
|
+
kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
|
|
1233
1264
|
);
|
|
1234
|
-
static const auto
|
|
1235
|
-
kp::shader_data::
|
|
1265
|
+
static const auto spirv_neox_f16 = getSpirvShader(
|
|
1266
|
+
kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
|
|
1267
|
+
);
|
|
1268
|
+
static const auto spirv_neox_f32 = getSpirvShader(
|
|
1269
|
+
kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
|
|
1236
1270
|
);
|
|
1237
1271
|
|
|
1238
1272
|
int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
|
|
@@ -1247,32 +1281,40 @@ static void ggml_vk_rope(
|
|
|
1247
1281
|
GGML_ASSERT(nb0 % type_size == 0);
|
|
1248
1282
|
|
|
1249
1283
|
struct PushConstants {
|
|
1250
|
-
uint32_t inAOff, inBOff, outOff;
|
|
1284
|
+
uint32_t inAOff, inBOff, inCOff, outOff;
|
|
1251
1285
|
int32_t n_dims, mode, n_ctx_orig;
|
|
1252
|
-
float freq_base, freq_scale
|
|
1286
|
+
float freq_base, freq_scale;
|
|
1287
|
+
bool has_freq_factors;
|
|
1288
|
+
float ext_factor, attn_factor, beta_fast, beta_slow;
|
|
1253
1289
|
uint32_t nb00, nb01, nb02, nb03;
|
|
1254
1290
|
int32_t ne0;
|
|
1255
1291
|
uint32_t nb0, nb1, nb2, nb3;
|
|
1256
1292
|
} pushConsts {
|
|
1257
|
-
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
|
|
1293
|
+
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
|
|
1258
1294
|
n_dims, mode, n_ctx_orig,
|
|
1259
|
-
freq_base, freq_scale,
|
|
1295
|
+
freq_base, freq_scale,
|
|
1296
|
+
has_freq_factors,
|
|
1297
|
+
ext_factor, attn_factor, beta_fast, beta_slow,
|
|
1260
1298
|
nb00, nb01, nb02, nb03,
|
|
1261
1299
|
ne0,
|
|
1262
1300
|
nb0, nb1, nb2, nb3
|
|
1263
1301
|
};
|
|
1264
1302
|
|
|
1265
|
-
auto
|
|
1303
|
+
auto & inC_ = inC ? inC : inA;
|
|
1304
|
+
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
1305
|
+
const bool is_f16 = src0t == GGML_TYPE_F16;
|
|
1306
|
+
|
|
1307
|
+
auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
|
|
1266
1308
|
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
|
1267
1309
|
if (!komputeManager()->hasAlgorithm(name)) {
|
|
1310
|
+
auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
|
|
1268
1311
|
s_algo = komputeManager()->algorithm<float, PushConstants>(
|
|
1269
|
-
name, s_kompute_context->pool.get(), {inA, inB, out},
|
|
1270
|
-
src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
|
|
1312
|
+
name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
|
|
1271
1313
|
{unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
|
|
1272
1314
|
);
|
|
1273
1315
|
} else {
|
|
1274
1316
|
s_algo = komputeManager()->getAlgorithm(name);
|
|
1275
|
-
s_algo->setTensors({inA, inB, out});
|
|
1317
|
+
s_algo->setTensors({inA, inB, inC_, out});
|
|
1276
1318
|
s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
|
|
1277
1319
|
s_algo->setPushConstants<PushConstants>({pushConsts});
|
|
1278
1320
|
s_algo->updateDescriptors(s_kompute_context->pool.get());
|
|
@@ -1351,11 +1393,15 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
|
|
|
1351
1393
|
}
|
|
1352
1394
|
|
|
1353
1395
|
static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
|
1396
|
+
int64_t n = ggml_nelements(op);
|
|
1354
1397
|
switch (op->op) {
|
|
1355
1398
|
case GGML_OP_UNARY:
|
|
1399
|
+
if (n % 4 != 0) return false;
|
|
1356
1400
|
switch (ggml_get_unary_op(op)) {
|
|
1357
|
-
case GGML_UNARY_OP_RELU:
|
|
1358
1401
|
case GGML_UNARY_OP_GELU:
|
|
1402
|
+
if (n % 8 != 0) return false;
|
|
1403
|
+
// fall through
|
|
1404
|
+
case GGML_UNARY_OP_RELU:
|
|
1359
1405
|
case GGML_UNARY_OP_SILU:
|
|
1360
1406
|
return ggml_is_contiguous(op->src[0]);
|
|
1361
1407
|
default:
|
|
@@ -1373,8 +1419,18 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
|
|
|
1373
1419
|
case GGML_OP_SOFT_MAX:
|
|
1374
1420
|
case GGML_OP_RMS_NORM:
|
|
1375
1421
|
case GGML_OP_NORM:
|
|
1376
|
-
case GGML_OP_ROPE:
|
|
1377
1422
|
return true;
|
|
1423
|
+
case GGML_OP_ROPE:
|
|
1424
|
+
{
|
|
1425
|
+
const int mode = ((const int32_t *) op->op_params)[2];
|
|
1426
|
+
if (mode & GGML_ROPE_TYPE_MROPE) {
|
|
1427
|
+
return false;
|
|
1428
|
+
}
|
|
1429
|
+
if (mode & GGML_ROPE_TYPE_VISION) {
|
|
1430
|
+
return false;
|
|
1431
|
+
}
|
|
1432
|
+
return true;
|
|
1433
|
+
}
|
|
1378
1434
|
case GGML_OP_DUP:
|
|
1379
1435
|
case GGML_OP_CPY:
|
|
1380
1436
|
case GGML_OP_CONT:
|
|
@@ -1413,8 +1469,8 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
|
|
|
1413
1469
|
|
|
1414
1470
|
switch (op->src[0]->type) {
|
|
1415
1471
|
case GGML_TYPE_F32:
|
|
1416
|
-
case GGML_TYPE_Q6_K:
|
|
1417
1472
|
return op->ne[3] == 1;
|
|
1473
|
+
case GGML_TYPE_Q6_K:
|
|
1418
1474
|
case GGML_TYPE_F16:
|
|
1419
1475
|
case GGML_TYPE_Q8_0:
|
|
1420
1476
|
case GGML_TYPE_Q4_0:
|
|
@@ -1515,9 +1571,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
|
1515
1571
|
const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
|
|
1516
1572
|
uint32_t off_src0 = 0;
|
|
1517
1573
|
uint32_t off_src1 = 0;
|
|
1574
|
+
uint32_t off_src2 = 0;
|
|
1518
1575
|
uint32_t off_dst = 0;
|
|
1519
1576
|
const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
|
|
1520
1577
|
const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
|
|
1578
|
+
const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
|
|
1521
1579
|
const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;
|
|
1522
1580
|
|
|
1523
1581
|
switch (dst->op) {
|
|
@@ -1593,11 +1651,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
|
1593
1651
|
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
|
1594
1652
|
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
|
|
1595
1653
|
|
|
1596
|
-
|
|
1597
|
-
|
|
1598
|
-
|
|
1654
|
+
const int64_t nrows_x = ggml_nrows(src0);
|
|
1655
|
+
const int64_t nrows_y = src0->ne[1];
|
|
1656
|
+
|
|
1657
|
+
const uint32_t n_head = nrows_x/nrows_y;
|
|
1658
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
|
1599
1659
|
|
|
1600
|
-
|
|
1660
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
1661
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
1662
|
+
|
|
1663
|
+
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
|
|
1601
1664
|
} break;
|
|
1602
1665
|
case GGML_OP_DIAG_MASK_INF:
|
|
1603
1666
|
{
|
|
@@ -1649,38 +1712,44 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
|
1649
1712
|
case GGML_TYPE_F16:
|
|
1650
1713
|
ggml_vk_mul_mat_f16(
|
|
1651
1714
|
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
|
1652
|
-
ne00, ne01, ne02, nb00, nb01, nb02,
|
|
1715
|
+
ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
|
1716
|
+
ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
|
|
1653
1717
|
ne0, ne1, r2, r3
|
|
1654
1718
|
);
|
|
1655
1719
|
break;
|
|
1656
1720
|
case GGML_TYPE_Q8_0:
|
|
1657
1721
|
ggml_vk_mul_mat_q8_0(
|
|
1658
1722
|
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
|
1659
|
-
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
|
|
1723
|
+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
|
|
1724
|
+
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
|
|
1660
1725
|
);
|
|
1661
1726
|
break;
|
|
1662
1727
|
case GGML_TYPE_Q4_0:
|
|
1663
1728
|
ggml_vk_mul_mat_q4_0(
|
|
1664
1729
|
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
|
1665
|
-
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
|
|
1730
|
+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
|
|
1731
|
+
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
|
|
1666
1732
|
);
|
|
1667
1733
|
break;
|
|
1668
1734
|
case GGML_TYPE_Q4_1:
|
|
1669
1735
|
ggml_vk_mul_mat_q4_1(
|
|
1670
1736
|
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
|
1671
|
-
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
|
|
1737
|
+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
|
|
1738
|
+
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
|
|
1672
1739
|
);
|
|
1673
1740
|
break;
|
|
1674
1741
|
case GGML_TYPE_Q4_K:
|
|
1675
1742
|
ggml_vk_mul_mat_q4_k(
|
|
1676
1743
|
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
|
1677
|
-
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
|
|
1744
|
+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
|
|
1745
|
+
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
|
|
1678
1746
|
);
|
|
1679
1747
|
break;
|
|
1680
1748
|
case GGML_TYPE_Q6_K:
|
|
1681
1749
|
ggml_vk_mul_mat_q6_k(
|
|
1682
1750
|
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
|
1683
|
-
ne00, ne10,
|
|
1751
|
+
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
|
|
1752
|
+
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
|
|
1684
1753
|
);
|
|
1685
1754
|
break;
|
|
1686
1755
|
default: {
|
|
@@ -1709,13 +1778,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
|
1709
1778
|
} break;
|
|
1710
1779
|
case GGML_OP_ROPE:
|
|
1711
1780
|
{
|
|
1712
|
-
#pragma message("TODO: implement phi3 frequency factors support")
|
|
1713
|
-
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
|
|
1714
|
-
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
|
|
1715
|
-
|
|
1716
|
-
#pragma message("TODO: update rope NORM mode to match NEOX mode")
|
|
1717
|
-
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
|
|
1718
|
-
|
|
1719
1781
|
GGML_ASSERT(ne10 == ne02);
|
|
1720
1782
|
GGML_ASSERT(src0t == dstt);
|
|
1721
1783
|
// const int n_past = ((int32_t *) dst->op_params)[0];
|
|
@@ -1724,6 +1786,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
|
1724
1786
|
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
|
|
1725
1787
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
1726
1788
|
|
|
1789
|
+
const bool has_freq_factors = dst->src[2] != nullptr;
|
|
1790
|
+
|
|
1727
1791
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
1728
1792
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
1729
1793
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
@@ -1732,8 +1796,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
|
1732
1796
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
1733
1797
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
1734
1798
|
ggml_vk_rope(
|
|
1735
|
-
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
|
|
1736
|
-
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
|
1799
|
+
seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
|
|
1800
|
+
freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
|
|
1737
1801
|
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
|
|
1738
1802
|
);
|
|
1739
1803
|
} break;
|
|
@@ -2176,9 +2240,12 @@ static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
|
|
|
2176
2240
|
|
|
2177
2241
|
ggml_backend_reg_t ggml_backend_kompute_reg() {
|
|
2178
2242
|
static ggml_backend_reg reg = {
|
|
2179
|
-
/* .
|
|
2180
|
-
/* .
|
|
2243
|
+
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
|
2244
|
+
/* .iface = */ ggml_backend_kompute_reg_i,
|
|
2245
|
+
/* .context = */ nullptr,
|
|
2181
2246
|
};
|
|
2182
2247
|
|
|
2183
2248
|
return &reg;
|
|
2184
2249
|
}
|
|
2250
|
+
|
|
2251
|
+
GGML_BACKEND_DL_IMPL(ggml_backend_kompute_reg)
|
|
@@ -4,19 +4,16 @@ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
|
|
4
4
|
|
|
5
5
|
message(STATUS "Metal framework found")
|
|
6
6
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
7
|
+
ggml_add_backend_library(ggml-metal
|
|
8
|
+
ggml-metal.m
|
|
9
|
+
)
|
|
10
10
|
|
|
11
11
|
target_link_libraries(ggml-metal PRIVATE
|
|
12
|
-
ggml-base
|
|
13
12
|
${FOUNDATION_LIBRARY}
|
|
14
13
|
${METAL_FRAMEWORK}
|
|
15
14
|
${METALKIT_FRAMEWORK}
|
|
16
15
|
)
|
|
17
16
|
|
|
18
|
-
target_include_directories(ggml-metal PRIVATE . ..)
|
|
19
|
-
|
|
20
17
|
if (GGML_METAL_NDEBUG)
|
|
21
18
|
add_compile_definitions(GGML_METAL_NDEBUG)
|
|
22
19
|
endif()
|
|
@@ -102,6 +102,21 @@ typedef struct {
|
|
|
102
102
|
uint64_t nb3;
|
|
103
103
|
} ggml_metal_kargs_cpy;
|
|
104
104
|
|
|
105
|
+
typedef struct {
|
|
106
|
+
int64_t ne10;
|
|
107
|
+
int64_t ne11;
|
|
108
|
+
int64_t ne12;
|
|
109
|
+
uint64_t nb10;
|
|
110
|
+
uint64_t nb11;
|
|
111
|
+
uint64_t nb12;
|
|
112
|
+
uint64_t nb13;
|
|
113
|
+
uint64_t nb1;
|
|
114
|
+
uint64_t nb2;
|
|
115
|
+
uint64_t nb3;
|
|
116
|
+
uint64_t offs;
|
|
117
|
+
bool inplace;
|
|
118
|
+
} ggml_metal_kargs_set;
|
|
119
|
+
|
|
105
120
|
typedef struct {
|
|
106
121
|
int32_t ne00;
|
|
107
122
|
int32_t ne01;
|
|
@@ -192,6 +207,30 @@ typedef struct {
|
|
|
192
207
|
int16_t r3;
|
|
193
208
|
} ggml_metal_kargs_mul_mv;
|
|
194
209
|
|
|
210
|
+
typedef struct {
|
|
211
|
+
int32_t ne00;
|
|
212
|
+
int32_t ne01;
|
|
213
|
+
int32_t ne02;
|
|
214
|
+
uint64_t nb00;
|
|
215
|
+
uint64_t nb01;
|
|
216
|
+
uint64_t nb02;
|
|
217
|
+
uint64_t nb03;
|
|
218
|
+
int32_t ne10;
|
|
219
|
+
int32_t ne11;
|
|
220
|
+
int32_t ne12;
|
|
221
|
+
uint64_t nb10;
|
|
222
|
+
uint64_t nb11;
|
|
223
|
+
uint64_t nb12;
|
|
224
|
+
uint64_t nb13;
|
|
225
|
+
int32_t ne0;
|
|
226
|
+
int32_t ne1;
|
|
227
|
+
int16_t r2;
|
|
228
|
+
int16_t r3;
|
|
229
|
+
int16_t nsg;
|
|
230
|
+
int16_t nxpsg;
|
|
231
|
+
int16_t r1ptg;
|
|
232
|
+
} ggml_metal_kargs_mul_mv_ext;
|
|
233
|
+
|
|
195
234
|
typedef struct {
|
|
196
235
|
int32_t nei0;
|
|
197
236
|
int32_t nei1;
|
|
@@ -20,6 +20,11 @@ find_package(MUSAToolkit)
|
|
|
20
20
|
if (MUSAToolkit_FOUND)
|
|
21
21
|
message(STATUS "MUSA Toolkit found")
|
|
22
22
|
|
|
23
|
+
if (NOT DEFINED MUSA_ARCHITECTURES)
|
|
24
|
+
set(MUSA_ARCHITECTURES "21;22")
|
|
25
|
+
endif()
|
|
26
|
+
message(STATUS "Using MUSA architectures: ${MUSA_ARCHITECTURES}")
|
|
27
|
+
|
|
23
28
|
file(GLOB GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
|
|
24
29
|
list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
|
|
25
30
|
|
|
@@ -44,15 +49,17 @@ if (MUSAToolkit_FOUND)
|
|
|
44
49
|
|
|
45
50
|
set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
|
|
46
51
|
foreach(SOURCE ${GGML_SOURCES_MUSA})
|
|
47
|
-
|
|
52
|
+
set(COMPILE_FLAGS "-x musa -mtgpu")
|
|
53
|
+
foreach(ARCH ${MUSA_ARCHITECTURES})
|
|
54
|
+
set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
|
|
55
|
+
endforeach()
|
|
56
|
+
set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS ${COMPILE_FLAGS})
|
|
48
57
|
endforeach()
|
|
49
58
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
target_link_libraries(ggml-musa PRIVATE ggml-base)
|
|
55
|
-
target_include_directories(ggml-musa PRIVATE . ..)
|
|
59
|
+
ggml_add_backend_library(ggml-musa
|
|
60
|
+
${GGML_HEADERS_MUSA}
|
|
61
|
+
${GGML_SOURCES_MUSA}
|
|
62
|
+
)
|
|
56
63
|
|
|
57
64
|
# TODO: do not use CUDA definitions for MUSA
|
|
58
65
|
target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
|