@fugood/llama.node 0.3.16 → 0.3.17
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +5 -0
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +8 -0
- package/src/LlamaCompletionWorker.h +1 -0
- package/src/LlamaContext.cpp +3 -2
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
- package/src/llama.cpp/.github/workflows/build.yml +70 -27
- package/src/llama.cpp/.github/workflows/docker.yml +6 -6
- package/src/llama.cpp/.github/workflows/server.yml +7 -11
- package/src/llama.cpp/CMakeLists.txt +23 -1
- package/src/llama.cpp/common/CMakeLists.txt +6 -3
- package/src/llama.cpp/common/arg.cpp +809 -105
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +1 -1
- package/src/llama.cpp/common/common.cpp +31 -521
- package/src/llama.cpp/common/common.h +17 -36
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +30 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
- package/src/llama.cpp/common/minja/minja.hpp +119 -93
- package/src/llama.cpp/common/sampling.cpp +3 -0
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +0 -9
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
- package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
- package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
- package/src/llama.cpp/examples/llava/clip.h +39 -22
- package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/examples/llava/llava.cpp +64 -52
- package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
- package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
- package/src/llama.cpp/examples/llava/mtmd.h +168 -0
- package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
- package/src/llama.cpp/examples/main/main.cpp +16 -5
- package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
- package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
- package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
- package/src/llama.cpp/examples/run/run.cpp +14 -28
- package/src/llama.cpp/examples/server/httplib.h +313 -247
- package/src/llama.cpp/examples/server/server.cpp +238 -139
- package/src/llama.cpp/examples/server/utils.hpp +51 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +6 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +66 -99
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -192
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1003 -13519
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +96 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +2 -292
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +204 -280
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +646 -114
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +17 -8
- package/src/llama.cpp/ggml/src/ggml.c +141 -245
- package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
- package/src/llama.cpp/include/llama.h +30 -11
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +2 -0
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/src/CMakeLists.txt +3 -2
- package/src/llama.cpp/src/llama-adapter.cpp +37 -1
- package/src/llama.cpp/src/llama-arch.cpp +160 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-chat.cpp +82 -17
- package/src/llama.cpp/src/llama-chat.h +6 -2
- package/src/llama.cpp/src/llama-context.cpp +108 -92
- package/src/llama.cpp/src/llama-context.h +1 -2
- package/src/llama.cpp/src/llama-graph.cpp +189 -119
- package/src/llama.cpp/src/llama-graph.h +26 -6
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
- package/src/llama.cpp/src/llama-kv-cache.h +41 -115
- package/src/llama.cpp/src/llama-memory.h +1 -1
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model.cpp +1760 -534
- package/src/llama.cpp/src/llama-model.h +13 -1
- package/src/llama.cpp/src/llama-quant.cpp +29 -8
- package/src/llama.cpp/src/llama-sampling.cpp +7 -1
- package/src/llama.cpp/src/llama-vocab.cpp +44 -6
- package/src/llama.cpp/src/llama.cpp +1 -1
- package/src/llama.cpp/tests/CMakeLists.txt +43 -30
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +82 -43
- package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
- package/src/llama.cpp/tests/test-chat.cpp +12 -2
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
@@ -271,6 +271,14 @@ static std::string var_to_str(ggml_op_pool pool) {
     }
 }
 
+static std::string var_to_str(ggml_scale_mode mode) {
+    switch (mode) {
+        case GGML_SCALE_MODE_NEAREST: return "nearest";
+        case GGML_SCALE_MODE_BILINEAR: return "bilinear";
+        default: return std::to_string(mode);
+    }
+}
+
 #define VAR_TO_STR(x) (#x "=" + var_to_str(x))
 
 #define VARS_TO_STR1(a) VAR_TO_STR(a)
@@ -2063,7 +2071,7 @@ struct test_mul_mat_id : public test_case {
     const ggml_type type_b;
     const int n_mats;
     const int n_used;
-    const bool b; //
+    const bool b; // broadcast b matrix
     const int64_t m;
     const int64_t n;
     const int64_t k;
@@ -2598,6 +2606,8 @@ struct test_rope : public test_case {
         } else {
             out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
         }
+
+        // TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp
     }
     ggml_set_name(out, "out");
 
@@ -2948,15 +2958,16 @@ struct test_upscale : public test_case {
     const std::array<int64_t, 4> ne;
     const int32_t scale_factor;
     const bool transpose;
+    const ggml_scale_mode mode;
 
     std::string vars() override {
-        return
+        return VARS_TO_STR5(type, ne, scale_factor, mode, transpose);
     }
 
     test_upscale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {512, 512, 3, 1},
-            int32_t scale_factor = 2, bool transpose = false)
-        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}
+            int32_t scale_factor = 2, ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST, bool transpose = false)
+        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose), mode(mode) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2967,7 +2978,7 @@ struct test_upscale : public test_case {
             ggml_set_name(a, "a_transposed");
         }
 
-        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
+        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor, mode);
         ggml_set_name(out, "out");
 
         return out;
@@ -2979,21 +2990,23 @@ struct test_upscale_ext : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const std::array<int64_t, 4> ne_tgt;
+    const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
 
     std::string vars() override {
-        return
+        return VARS_TO_STR4(type, ne, ne_tgt, mode);
     }
 
     test_upscale_ext(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {2, 5, 7, 11},
-            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13}
-
+            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},
+            ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
+        : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
+        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);
         ggml_set_name(out, "out");
 
         return out;
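The upscale changes above (tests/test-backend-ops.cpp) track a ggml API extension: ggml_upscale and ggml_upscale_ext now take a trailing ggml_scale_mode argument (GGML_SCALE_MODE_NEAREST or GGML_SCALE_MODE_BILINEAR). A minimal caller sketch using the new parameter; the helper name and the 2x factor are illustrative, not part of the diff:

#include "ggml.h"

// Sketch only: build a graph node that upscales a 4D tensor by 2x using
// bilinear interpolation via the extended signature seen in this diff
// (the previous ggml_upscale had no mode argument).
static ggml_tensor * upscale_bilinear_2x(ggml_context * ctx, ggml_tensor * src) {
    return ggml_upscale(ctx, src, /*scale_factor=*/2, GGML_SCALE_MODE_BILINEAR);
}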
@@ -3217,7 +3230,8 @@ struct test_leaky_relu : public test_case {
 
 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
-    const int64_t
+    const int64_t hsk; // K head size
+    const int64_t hsv; // V head size
     const int64_t nh; // num heads
     const int64_t nr; // repeat in Q, tests for grouped-query attention
     const int64_t kv; // kv size
@@ -3233,7 +3247,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return
+        return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3243,17 +3257,18 @@ struct test_flash_attn_ext : public test_case {
     uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
         // Just counting matmul costs:
-        // Q*K^T is nb x
-        return 2 *
+        // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
+        return 2 * nh*nr * nb * (hsk + hsv) * kv;
     }
 
-    test_flash_attn_ext(int64_t
+    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
                         bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
                         ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        :
+        : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        const int64_t
+        const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
+        const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
 
         auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
             int64_t ne[4] = {ne0, ne1, ne2, ne3};
@@ -3268,13 +3283,13 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32,
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = create_permuted(type_KV,
+        ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV,
+        ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -3283,7 +3298,7 @@ struct test_flash_attn_ext : public test_case {
             ggml_set_name(m, "m");
         }
 
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
         ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");
 
@@ -4169,6 +4184,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            // test cases with large ne00/ne10 to cover stream-k fixup
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
         }
     }
     for (ggml_type type_a : other_types) {
@@ -4204,6 +4224,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 64, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
 
     for (auto bs : {1,2,4,8}) {
         for (auto nr : {1,4}) {
@@ -4395,12 +4417,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
     }
 
+    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
+        test_cases.emplace_back(new test_upscale_ext(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
+    }
+
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_mean());
-    test_cases.emplace_back(new test_upscale());
-    test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
-    test_cases.emplace_back(new test_upscale_ext());
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
@@ -4410,27 +4435,33 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
-    for (int
-        for (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if (
+    for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
+        for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
+            if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
+            if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+            if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
+
+            for (bool mask : { true, false } ) {
+                for (float max_bias : { 0.0f, 8.0f }) {
+                    if (!mask && max_bias > 0.0f) continue;
+                    for (float logit_softcap : {0.0f, 10.0f}) {
+                        if (hsk != 128 && logit_softcap != 0.0f) continue;
+                        for (int nh : { 4, }) {
+                            for (int nr : { 1, 4, 16 }) {
+                                if (nr == 16 && hsk != 128) continue;
+                                for (int kv : { 512, 1024, }) {
+                                    if (nr != 1 && kv != 512) continue;
+                                    for (int nb : { 1, 3, 32, 35, }) {
+                                        for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                            if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                                 test_cases.emplace_back(new test_flash_attn_ext(
-
+                                                    hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                                // run fewer test cases permuted
+                                                if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                    test_cases.emplace_back(new test_flash_attn_ext(
+                                                        hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                }
                                             }
                                         }
                                     }
@@ -4507,6 +4538,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }
 
+    for (int kv : { 4096, 8192, 16384, }) {
+        for (int hs : { 64, 128, }) {
+            for (int nr : { 1, 4, }) {
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            }
+        }
+    }
+
     return test_cases;
 }
 
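The reworked op_flops estimate above counts only the two attention matmuls: Q·K^T costs roughly 2·hsk·kv flops per query row and P·V roughly 2·hsv·kv, summed over nh·nr heads and nb query rows, giving 2 * nh*nr * nb * (hsk + hsv) * kv. A small self-contained check of that arithmetic; the concrete sizes below are made up for illustration and are not taken from the diff:

#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative sizes: 8 heads, no GQA repeat, 8 query rows,
    // K/V head sizes of 128, 512 cached KV positions.
    const int64_t nh = 8, nr = 1, nb = 8, hsk = 128, hsv = 128, kv = 512;

    // Same formula as the updated test_flash_attn_ext::op_flops().
    const int64_t flops = 2 * nh * nr * nb * (hsk + hsv) * kv;
    printf("estimated matmul flops: %lld\n", (long long) flops); // prints 16777216
    return 0;
}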
@@ -19,6 +19,8 @@ static std::string normalize_newlines(const std::string & s) {
 #endif
 }
 
+#define U8C(x) (const char*)(u8##x)
+
 static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
     common_chat_msg msg;
     msg.role = role;
@@ -35,6 +37,8 @@ int main(void) {
         {"assistant", " I am an assistant "},
         {"user", "Another question"},
     };
+
+    // std::string wrong = /* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}";
     struct TestCase {
         std::string name;
         std::string template_str;
@@ -177,24 +181,25 @@
         },
         {
             /* .name= */ "ChatGLM4",
-            /* .template_str= */
+            /* .template_str= */ U8C("[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}"),
             /* .expected_output= */ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
-
-
-
-
-
-
-
-
+        // TODO @ngxson : GLMEdge produces poor result without `[gMASK]<sop>`, so we're temporarily using GLM4 template for it. We should fix this in the future.
+        // {
+        //     /* .name= */ "GLMEdge",
+        //     /* .template_str= */ "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}<|assistant|>",
        //     /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+        //     /* .expected_output_jinja= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+        //     /* .bos_token= */ "",
+        //     /* .eos_token= */ "",
+        // },
         {
             /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
-            /* .template_str= */
-            /* .expected_output= */
+            /* .template_str= */ U8C("{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}"),
+            /* .expected_output= */ U8C("You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>"),
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
@@ -202,7 +207,7 @@ int main(void) {
         {
             /* .name= */ "DeepSeek-V2",
             /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
-            /* .expected_output= */
+            /* .expected_output= */ U8C("You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:"),
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "<|end▁of▁sentence|>",
@@ -256,7 +261,7 @@
         },
         {
             /* .name= */ "Infinigence/Megrez-3B-Instruct",
-            /* .template_str= */
+            /* .template_str= */ U8C("{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}"),
             /* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
@@ -270,6 +275,22 @@ int main(void) {
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
+        {
+            /* .name= */ "yandex/YandexGPT-5-Lite-8B-instruct",
+            /* .template_str= */ "<s>{%- set names = {'assistant': ' Ассистент:', 'user': ' Пользователь:'} %}\n{%- set tools_prefix = 'Тебе доступны следующие функции:' %}\n{%- macro __render_tool(tool) %}\n {%- set name = tool.function.name %}\n {%- set description = tool.function.description|default('') %}\n {%- set parameters = tool.function.parameters|tojson %}\n {{- '\\n' }}function {{ '{' }}'name':'{{ name }}',\n {%- if tool.function.description %}'description':'{{ description }}',{% endif %}\n'parameters':{{ parameters }}\n {{- '}' }}\n{%- endmacro %}\n{%- macro __render_tools(tools) %}\n {{- tools_prefix }}\n {%- for tool in tools %}\n {{- __render_tool(tool) }}\n {%- endfor %}\n {{- '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_tool_message(message) %}\n {{- '\\n\\nРезультат вызова' }} {{ message.name }}: {{ message.content }} {{ '\\n\\n' }}\n{%- endmacro %}\n{%- if tools -%}\n {{- __render_tools(tools) }}\n{%- endif -%}\n{%- macro __render_user_message(message) %}\n{{ names.user }} {{ message.content + '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_assistant_message(message) %}\n {{- names.assistant }}\n {%- set call = message['function_call'] %}\n {%- if call %}\n {{- '\\n[TOOL_CALL_START]' }}{{ call.name }}{{ '\\n' }}{{ call.arguments|tojson }}\n {%- else %}\n {{- ' ' + message.content + '\\n\\n' }}\n {%- endif %}\n{%- endmacro %}\n{%- if not add_generation_prompt is defined %}\n{%- set add_generation_prompt = false %}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'user' %}\n {{- __render_user_message(message) }}\n {%- endif %}\n {%- if message.role == 'assistant' and not loop.last %}\n {{- __render_assistant_message(message) }}\n {%- endif %}\n {%- if message.role == 'tool' %}\n {{- __render_tool_message(message) }}\n {%- endif %}\n {%- if loop.last %}\n {{- ' Ассистент:[SEP]' }}\n {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */ "<s> Пользователь: Hello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент: I am an assistant \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .expected_output_jinja= */ "<s> Пользователь: You are a helpful assistant\nHello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент: I am an assistant \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "inclusionAI/Ling-lite",
+            /* .template_str */ "{% for message in messages %}{% set role = message['role'] | lower %}{% if role == 'user' %}{% set role = 'HUMAN' %}{% endif %}{% set role = role | upper %}{{ '<role>' + role + '</role>' + message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '<role>ASSISTANT</role>' }}{% endif %}",
+            /* .expected_output= */ "<role>SYSTEM</role>You are a helpful assistant<role>HUMAN</role>Hello<role>ASSISTANT</role>Hi there<role>HUMAN</role>Who are you<role>ASSISTANT</role> I am an assistant <role>HUMAN</role>Another question<role>ASSISTANT</role>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
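The U8C macro added above wraps u8 string literals back into const char*, presumably so the UTF-8 chat-template fixtures keep compiling under C++20, where u8"…" literals have type const char8_t[] and no longer convert implicitly to const char*. A standalone illustration of the same pattern, independent of the test harness:

#include <cstdio>

// Same shape as the macro in test-chat-template.cpp: cast a UTF-8 literal
// (char8_t[] in C++20) to const char* so C-string based fixtures accept it.
#define U8C(x) (const char*)(u8##x)

int main() {
    const char * tmpl = U8C("<用户>question<AI>answer");
    printf("%s\n", tmpl);
    return 0;
}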
@@ -11,8 +11,9 @@
 #include <string>
 
 #include "chat.h"
-
-#include "unicode.h"
+
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
 
 using json = nlohmann::ordered_json;
 
@@ -569,6 +570,7 @@ static void test_template_output_parsers() {
     {
         // Not supported yet
         auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
     }
     {
@@ -665,6 +667,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
         std::vector<std::string> end_tokens{ "<|im_end|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(
             COMMON_CHAT_FORMAT_HERMES_2_PRO,
@@ -793,6 +796,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                       common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
@@ -815,6 +819,7 @@ static void test_template_output_parsers() {
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
@@ -824,6 +829,8 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
+                      common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                       common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
@@ -851,6 +858,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
         std::vector<std::string> end_tokens{ "<|eot_id|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
@@ -862,6 +870,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
@@ -891,6 +900,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
@@ -2,7 +2,6 @@
 # undef NDEBUG
 #endif
 
-#include "unicode.h"
 #include "sampling.h"
 
 #include <cassert>
@@ -84,7 +83,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
 
         fprintf(stderr,
                 "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following "
-                "command: ./
+                "command: ./test-gbnf-validator test-grammar-integration.grammar.gbnf "
                 "test-grammar-integration.string.txt\n\n");
     } else {
         fprintf(stdout, "✅︎\n");
@@ -1086,6 +1085,65 @@ static void test_json_schema() {
     });
 }
 
+static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
+    auto n_vocab = tok_arr.size;
+
+    tok_arr.selected = -1;
+    tok_arr.sorted = false;
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        tok_arr.data[token_id].id = token_id;
+        tok_arr.data[token_id].logit = 0.0f;
+    }
+
+    tok_arr.data[selected].logit = 100.0f;
+}
+
+static void test_sampler_chain(void) {
+    auto sparams = llama_sampler_chain_default_params();
+    sparams.no_perf = false;
+    llama_sampler * sampler = llama_sampler_chain_init(sparams);
+
+    const auto grammar_data = R"(%llguidance {}
+start: /[A-Z ]*/)";
+
+    llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
+    llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
+
+    auto input = "ALL YOUR BASE ARE BELONG TO US";
+    auto tokens = common_tokenize(vocab, input, false, false);
+
+    auto n_vocab = llama_vocab_n_tokens(vocab);
+
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
+    }
+    auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
+
+    for (const auto token : tokens) {
+        one_hot(tok_arr, token);
+
+        fprintf(stderr, "applying token: %d\n", token);
+        llama_sampler_apply(sampler, &tok_arr);
+
+        auto idx = tok_arr.selected;
+        fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
+        assert(cur[tok_arr.selected].id == token);
+        llama_sampler_accept(sampler, token);
+    }
+
+    auto tok_eos = llama_vocab_eot(vocab);
+    if (tok_eos == LLAMA_TOKEN_NULL) {
+        tok_eos = llama_vocab_eos(vocab);
+    }
+
+    one_hot(tok_arr, tok_eos);
+
+    llama_sampler_apply(sampler, &tok_arr);
+    assert(cur[tok_arr.selected].id == tok_eos);
+}
+
 int main(int argc, const char ** argv) {
     fprintf(stdout, "Running llguidance integration tests...\n");
 
@@ -1135,6 +1193,9 @@ int main(int argc, const char ** argv) {
     test_special_chars();
     test_quantifiers();
     test_json_schema();
+
+    test_sampler_chain();
+
     fprintf(stdout, "All tests passed.\n");
     return 0;
 }
@@ -4,7 +4,7 @@
 
 #include "json-schema-to-grammar.h"
 
-#include "llama-grammar.h"
+#include "../src/llama-grammar.h"
 
 #include <cassert>
 #include <fstream>
@@ -597,6 +597,22 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         )"""
     });
 
+    test({
+        SUCCESS,
+        "maxItems 0",
+        R"""({
+            "items": {
+                "type": "boolean"
+            },
+            "maxItems": 0
+        })""",
+        R"""(
+            boolean ::= ("true" | "false") space
+            root ::= "[" space "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
     test({
         SUCCESS,
         "maxItems 1",