@fugood/llama.node 0.3.3 → 0.3.4
This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- package/CMakeLists.txt +5 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +15 -5
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +1 -1
- package/src/LlamaContext.cpp +81 -18
- package/src/LlamaContext.h +2 -0
- package/src/llama.cpp/.github/workflows/build.yml +197 -159
- package/src/llama.cpp/.github/workflows/docker.yml +5 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +11 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -2
- package/src/llama.cpp/common/arg.cpp +426 -245
- package/src/llama.cpp/common/common.cpp +143 -80
- package/src/llama.cpp/common/common.h +81 -24
- package/src/llama.cpp/common/sampling.cpp +53 -19
- package/src/llama.cpp/common/sampling.h +22 -1
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +101 -148
- package/src/llama.cpp/examples/CMakeLists.txt +32 -13
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +5 -4
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +262 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/llava.cpp +46 -19
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
- package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +9 -5
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
- package/src/llama.cpp/examples/server/server.cpp +1758 -886
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +94 -304
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +4 -0
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
- package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +106 -24
- package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
- package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
- package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
- package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
- package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
- package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
- package/src/llama.cpp/ggml/src/ggml.c +367 -207
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +26 -19
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/src/CMakeLists.txt +2 -7
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +35 -90
- package/src/llama.cpp/src/llama-vocab.cpp +6 -1
- package/src/llama.cpp/src/llama.cpp +1748 -640
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -37
- package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
- package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
- package/src/llama.cpp/tests/test-rope.cpp +61 -20
- package/src/llama.cpp/tests/test-sampling.cpp +2 -2
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446

package/src/llama.cpp/tests/test-backend-ops.cpp (+140 -21)

@@ -16,7 +16,6 @@


 #include <ggml.h>
-#include <ggml-cpu.h>
 #include <ggml-alloc.h>
 #include <ggml-backend.h>

@@ -26,7 +25,6 @@
 #include <cstdint>
 #include <cstring>
 #include <cinttypes>
-#include <functional>
 #include <memory>
 #include <random>
 #include <stdio.h>
@@ -639,19 +637,20 @@ struct test_case {

         // determine number of runs
         int n_runs;
+        bool is_cpu = ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_CPU;
         if (op_flops(out) > 0) {
             // based on flops
             const uint64_t GFLOP = 1000 * 1000 * 1000;
             const uint64_t target_flops_cpu = 8ULL * GFLOP;
             const uint64_t target_flops_gpu = 100ULL * GFLOP;
-            uint64_t target_flops =
+            uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
             n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
         } else {
             // based on memory size
             const size_t GB = 1ULL << 30;
             const size_t target_size_cpu = 8 * GB;
             const size_t target_size_gpu = 32 * GB;
-            size_t target_size =
+            size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
             n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
         }

@@ -819,7 +818,6 @@ struct test_case {
             }
         }

-        // TODO: refactor so that this check is only needed once
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (!ggml_backend_supports_op(backend, t)) {
                 printf("not supported [%s] ", ggml_backend_name(backend));
@@ -1155,6 +1153,26 @@ struct test_argmax : public test_case {
         return out;
     }

+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_F32) {
+                // initialize with unique values to avoid ties
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<float> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+
     double max_nmse_err() override {
         return 0.0;
     }
@@ -2183,7 +2201,15 @@ struct test_rope : public test_case {
             ggml_set_name(a, "a");
         }

-
+        const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+        ggml_tensor * pos;
+        if (is_mrope || is_vision) {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
+        } else {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+        }
         ggml_set_name(pos, "pos");

         ggml_tensor * freq = nullptr;
@@ -2192,7 +2218,20 @@ struct test_rope : public test_case {
             ggml_set_name(freq, "freq");
         }

-        ggml_tensor * out
+        ggml_tensor * out;
+        if (is_mrope) {
+            if (is_vision) {
+                GGML_ASSERT(n_dims/4 > 0);
+                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
+                out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            } else {
+                GGML_ASSERT(n_dims/3 > 0);
+                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
+                out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            }
+        } else {
+            out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+        }
         ggml_set_name(out, "out");

         return out;
@@ -2202,11 +2241,12 @@ struct test_rope : public test_case {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
                 // pos
-
-
+                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
+                std::vector<int> data(num_pos_ids);
+                for (int i = 0; i < num_pos_ids; i++) {
                     data[i] = rand() % n_ctx;
                 }
-                ggml_backend_tensor_set(t, data.data(), 0,
+                ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));
             } else {
                 if (t->ne[0] == n_dims/2) {
                     // frequency factors in the range [0.9f, 1.1f]
@@ -2679,6 +2719,33 @@ struct test_pad : public test_case {
     }
 };

+// GGML_OP_PAD_REFLECT_1D
+struct test_pad_reflect_1d : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const int pad_0;
+    const int pad_1;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+    }
+
+    test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {512, 34, 2, 1},
+            int pad_0 = 10, int pad_1 = 9)
+        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 2, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_ARANGE
 struct test_arange : public test_case {
     const ggml_type type;
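
Note: ggml_pad_reflect_1d, exercised by the new test_pad_reflect_1d case above, is a new ggml op in this vendored llama.cpp revision. As a rough sketch of what reflect padding means, assuming the op follows the usual convention of mirroring around the edge elements without repeating them (this standalone helper is illustrative only, not the ggml kernel):

    // Illustrative 1-D reflect padding over a single row (assumed semantics).
    #include <cstdio>
    #include <vector>

    static std::vector<float> pad_reflect_1d(const std::vector<float> & src, int p0, int p1) {
        const int n = (int) src.size(); // requires p0 < n and p1 < n
        std::vector<float> dst(p0 + n + p1);
        for (int i = 0; i < p0; i++) dst[i] = src[p0 - i];             // left mirror, edge not repeated
        for (int i = 0; i < n;  i++) dst[p0 + i] = src[i];             // original data
        for (int i = 0; i < p1; i++) dst[p0 + n + i] = src[n - 2 - i]; // right mirror
        return dst;
    }

    int main() {
        // {1, 2, 3, 4} with pad_0 = 2, pad_1 = 1 -> {3, 2, 1, 2, 3, 4, 3}
        for (float v : pad_reflect_1d({1, 2, 3, 4}, 2, 1)) printf("%g ", v);
        printf("\n");
        return 0;
    }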
@@ -3316,7 +3383,9 @@ static const ggml_type all_types[] = {

 static const ggml_type base_types[] = {
     GGML_TYPE_F32, GGML_TYPE_F16,
+    GGML_TYPE_Q8_0, // for I8MM tests
     GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1, // for I8MM tests
     GGML_TYPE_Q4_K,
     GGML_TYPE_IQ2_XXS
 };
@@ -3440,9 +3509,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

-    test_cases.emplace_back(new test_argmax());
     test_cases.emplace_back(new test_count_equal());

+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1}));
+
     for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
@@ -3468,10 +3543,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
     }

+    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
+        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
+    }
+
     for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         for (ggml_type type_dst : all_types) {
-
-
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
         }
     }
     for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
@@ -3547,6 +3626,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));

+    for (int i = 1; i < 9; ++i) {
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16,    GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0,   GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1,   GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0,   GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1,   GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0,   GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K,   GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K,   GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K,   GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1,  1}, {1, 1}));
+    }
+
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
@@ -3743,6 +3835,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
             }

+            if (all) {
+                test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 2B)
+                test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 7B)
+                test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl ViT)
+            }
+
             test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
         }
     }
@@ -3773,9 +3871,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_upscale());
     test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
     test_cases.emplace_back(new test_upscale_ext());
-    test_cases.emplace_back(new test_group_norm());
+    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
+    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());
+    test_cases.emplace_back(new test_pad_reflect_1d());
     test_cases.emplace_back(new test_arange());
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
@@ -3822,6 +3922,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));

     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
+
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
+
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));

     for (int bs : {1, 512}) {
         for (ggml_type type_a : all_types) {
@@ -3837,7 +3951,11 @@
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     if (mode == MODE_TEST) {
         auto test_cases = make_test_cases_eval();
-        ggml_backend_t backend_cpu =
+        ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
+        if (backend_cpu == NULL) {
+            printf("  Failed to initialize CPU backend\n");
+            return false;
+        }

         size_t n_ok = 0;
         for (auto & test : test_cases) {
@@ -3917,7 +4035,9 @@ int main(int argc, char ** argv) {
         }
     }

-    // enumerate backends
+    // load and enumerate backends
+    ggml_backend_load_all();
+
     printf("Testing %zu devices\n\n", ggml_backend_dev_count());

     size_t n_ok = 0;
@@ -3933,16 +4053,15 @@ int main(int argc, char ** argv) {
             continue;
         }

-
-        GGML_ASSERT(backend != NULL);
-
-        if (backend_filter == NULL && ggml_backend_is_cpu(backend) && mode != MODE_GRAD) {
+        if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {
             printf("  Skipping CPU backend\n");
-            ggml_backend_free(backend);
             n_ok++;
             continue;
         }

+        ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
+        GGML_ASSERT(backend != NULL);
+
         ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
         auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
         if (ggml_backend_set_n_threads_fn) {
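
Note: the test-backend-ops.cpp changes above track ggml's move to dynamic backend loading and device-centric initialization (see also ggml-backend-reg.cpp and ggml-backend.h in the file list). A minimal sketch of the new flow, assembled from the calls visible in these hunks; ggml_backend_dev_get and ggml_backend_dev_name are standard ggml-backend.h accessors assumed to be available here:

    #include <ggml-backend.h>
    #include <cstdio>

    int main() {
        // discover and load all available backend libraries (new in this revision)
        ggml_backend_load_all();

        for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);

            // device properties can now be queried before any backend instance exists
            const bool is_cpu = ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU;
            printf("device %zu: %s%s\n", i, ggml_backend_dev_name(dev), is_cpu ? " (CPU)" : "");

            ggml_backend_t backend = ggml_backend_dev_init(dev, /*params =*/ NULL);
            if (backend == NULL) {
                continue;
            }
            // ... run tests against `backend` ...
            ggml_backend_free(backend);
        }
        return 0;
    }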

package/src/llama.cpp/tests/test-chat-template.cpp (+50 -4)

@@ -21,7 +21,7 @@ int main(void) {
     std::vector<std::string> templates = {
         // teknium/OpenHermes-2.5-Mistral-7B
         "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
-        // mistralai/Mistral-7B-Instruct-v0.2
+        // mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)
         "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
         // TheBloke/FusionNet_34Bx2_MoE-AWQ
         "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
@@ -67,16 +67,26 @@ int main(void) {
         "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
         // ibm-granite/granite-3.0-8b-instruct
         "{%- if tools %}\n    {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n    {%- for tool in tools %}\n    {{- tool | tojson(indent=4) }}\n    {%- if not loop.last %}\n        {{- '\n\n' }}\n    {%- endif %}\n    {%- endfor %}\n    {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n    {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'user' %}\n    {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'assistant' %}\n    {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'assistant_tool_call' %}\n    {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'tool_response' %}\n    {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- endif %}\n    {%- if loop.last and add_generation_prompt %}\n    {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n    {%- endif %}\n{%- endfor %}",
+        // mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)
+        "{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content'] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n    {%- endif %}\n    {%- if message['role'] == 'user' %}\n        {%- if loop.first and system_message is defined %}\n            {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n        {%- else %}\n            {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n        {%- endif %}\n    {%- elif message['role'] == 'assistant' %}\n        {{- ' ' + message['content'] + eos_token}}\n    {%- else %}\n        {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n    {%- endif %}\n{%- endfor %}\n",
+        // Mistral-Large-Instruct-2407 (mistralai 'v3' template)
+        "{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS] [\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n        {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n        {{- \"[TOOL_CALLS] [\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- \" \" + message[\"content\"]|trim + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n",
+        // Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template)
+        "{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS][\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n        {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n        {{- \"[TOOL_CALLS][\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- message[\"content\"] + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n",
+        // mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)
+        "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}",
+        // ai-sage/GigaChat-20B-A3B-instruct
+        "{% if messages[0]['role'] == 'system' -%}\n    {%- set loop_messages = messages[1:] -%}\n    {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n    {%- set loop_messages = messages -%}\n    {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n    {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n    {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n    {% endif %}\n    \n    {%- if loop.index0 == 0 -%}\n        {{ system_message -}}\n    {%- endif -%}\n    {%- if message['role'] == 'user' -%}\n        {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n        {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n    {%- endif -%}\n    {%- if message['role'] == 'assistant' -%}\n        {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n    {%- endif -%}\n    {%- if loop.last and add_generation_prompt -%}\n        {{ 'assistant' + additional_special_tokens[0] -}}\n    {%- endif -%}\n{%- endfor %}",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
         "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
-        // mistralai/Mistral-7B-Instruct-v0.2
+        // mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)
         "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
         // TheBloke/FusionNet_34Bx2_MoE-AWQ
-        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]
+        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s><s>[INST] Who are you [/INST]   I am an assistant   </s><s>[INST] Another question [/INST]",
         // bofenghuang/vigogne-2-70b-chat
-        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]
+        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s>[INST] Who are you [/INST]I am an assistant</s>[INST] Another question [/INST]",
         // mlabonne/AlphaMonarch-7B
         "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n   I am an assistant   </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
         // google/gemma-7b-it
@@ -113,10 +123,31 @@ int main(void) {
         u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant:    I am an assistant   <|end▁of▁sentence|>User: Another question\n\nAssistant:",
         // ibm-granite/granite-3.0-8b-instruct
         "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>   I am an assistant   <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
+        // mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)
+        " [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST]    I am an assistant   </s> [INST] Another question [/INST]",
+        // Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)
+        "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
+        // Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)
+        "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST]   I am an assistant   </s>[INST]Another question[/INST]",
+        // mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)
+        "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST]    I am an assistant   </s>[INST] Another question[/INST]",
+        // ai-sage/GigaChat-20B-A3B-instruct
+        "<s>You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>   I am an assistant   <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;

+    // list all supported templates
+    std::vector<const char *> supported_tmpl;
+    res = llama_chat_builtin_templates(nullptr, 0);
+    assert(res > 0);
+    supported_tmpl.resize(res);
+    res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size());
+    printf("Built-in chat templates:\n");
+    for (auto tmpl : supported_tmpl) {
+        printf("  %s\n", tmpl);
+    }
+
     // test invalid chat template
     res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
@@ -154,9 +185,16 @@ int main(void) {
         return output;
     };
     assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
+    assert(fmt_sys("mistral-v1") == " [INST] You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v3") == "[INST] You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v3-tekken") == "[INST]You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v7") == "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT]");
     assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
+    assert(fmt_sys("llama2-sys") == "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\n");
+    assert(fmt_sys("mistral") == "[INST] You are a helpful assistant\n"); // for old pre-v1 templates
    assert(fmt_sys("gemma")  == ""); // for gemma, system message is merged with user message
     assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
+    assert(fmt_sys("gigachat") == "<s>You are a helpful assistant<|message_sep|>");


     // test llama_chat_format_single for user message
@@ -173,9 +211,17 @@ int main(void) {
         return output;
     };
     assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("mistral-v1") == " [INST] How are you [/INST]");
+    assert(fmt_single("mistral-v3") == "[INST] How are you[/INST]");
+    assert(fmt_single("mistral-v3-tekken") == "[INST]How are you[/INST]");
+    assert(fmt_single("mistral-v7") == "[INST] How are you[/INST]");
     assert(fmt_single("llama2") == "[INST] How are you [/INST]");
+    assert(fmt_single("mistral") == "[INST] How are you [/INST]"); // for old pre-v1 templates
     assert(fmt_single("gemma")  == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+    assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
+
+    printf("Test chat templates: OK\n");

     return 0;
 }