@fugood/llama.node 0.3.15 → 0.3.17
- package/CMakeLists.txt +3 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +5 -0
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +8 -0
- package/src/LlamaCompletionWorker.h +1 -0
- package/src/LlamaContext.cpp +3 -2
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
- package/src/llama.cpp/.github/workflows/build.yml +70 -27
- package/src/llama.cpp/.github/workflows/docker.yml +6 -6
- package/src/llama.cpp/.github/workflows/server.yml +7 -11
- package/src/llama.cpp/CMakeLists.txt +23 -1
- package/src/llama.cpp/common/CMakeLists.txt +6 -3
- package/src/llama.cpp/common/arg.cpp +809 -105
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +1 -1
- package/src/llama.cpp/common/common.cpp +31 -521
- package/src/llama.cpp/common/common.h +17 -36
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +30 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
- package/src/llama.cpp/common/minja/minja.hpp +119 -93
- package/src/llama.cpp/common/sampling.cpp +3 -0
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +0 -9
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
- package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
- package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
- package/src/llama.cpp/examples/llava/clip.h +39 -22
- package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/examples/llava/llava.cpp +64 -52
- package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
- package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
- package/src/llama.cpp/examples/llava/mtmd.h +168 -0
- package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
- package/src/llama.cpp/examples/main/main.cpp +16 -5
- package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
- package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
- package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
- package/src/llama.cpp/examples/run/run.cpp +14 -28
- package/src/llama.cpp/examples/server/httplib.h +313 -247
- package/src/llama.cpp/examples/server/server.cpp +243 -139
- package/src/llama.cpp/examples/server/utils.hpp +51 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +14 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +66 -99
- package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
- package/src/llama.cpp/ggml/src/ggml.c +141 -245
- package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
- package/src/llama.cpp/include/llama.h +30 -11
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +2 -0
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/src/CMakeLists.txt +3 -2
- package/src/llama.cpp/src/llama-adapter.cpp +37 -1
- package/src/llama.cpp/src/llama-arch.cpp +161 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-chat.cpp +82 -17
- package/src/llama.cpp/src/llama-chat.h +6 -2
- package/src/llama.cpp/src/llama-context.cpp +108 -92
- package/src/llama.cpp/src/llama-context.h +1 -2
- package/src/llama.cpp/src/llama-graph.cpp +189 -119
- package/src/llama.cpp/src/llama-graph.h +26 -6
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
- package/src/llama.cpp/src/llama-kv-cache.h +41 -115
- package/src/llama.cpp/src/llama-memory.h +1 -1
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model.cpp +1544 -291
- package/src/llama.cpp/src/llama-model.h +13 -1
- package/src/llama.cpp/src/llama-quant.cpp +29 -8
- package/src/llama.cpp/src/llama-sampling.cpp +7 -1
- package/src/llama.cpp/src/llama-vocab.cpp +44 -6
- package/src/llama.cpp/src/llama.cpp +1 -1
- package/src/llama.cpp/tests/CMakeLists.txt +43 -30
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
- package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
- package/src/llama.cpp/tests/test-chat.cpp +12 -2
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/tests/test-backend-ops.cpp:

@@ -271,6 +271,14 @@ static std::string var_to_str(ggml_op_pool pool) {
     }
 }

+static std::string var_to_str(ggml_scale_mode mode) {
+    switch (mode) {
+        case GGML_SCALE_MODE_NEAREST:  return "nearest";
+        case GGML_SCALE_MODE_BILINEAR: return "bilinear";
+        default:                       return std::to_string(mode);
+    }
+}
+
 #define VAR_TO_STR(x) (#x "=" + var_to_str(x))

 #define VARS_TO_STR1(a) VAR_TO_STR(a)
@@ -1463,11 +1471,13 @@ struct test_cpy : public test_case {
     const ggml_type type_src;
     const ggml_type type_dst;
     const std::array<int64_t, 4> ne;
-    const std::array<int64_t, 4>
+    const std::array<int64_t, 4> permute_src;
+    const std::array<int64_t, 4> permute_dst;
     bool _src_use_permute;
+    bool _dst_use_permute;

     std::string vars() override {
-        return
+        return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
     }

     double max_nmse_err() override {
@@ -1480,9 +1490,11 @@ struct test_cpy : public test_case {

     test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            std::array<int64_t, 4>
-
-
+            std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
+            std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
+        : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
+          _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
+          _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -1490,13 +1502,18 @@ struct test_cpy : public test_case {
         ggml_set_name(src, "src");

         if (_src_use_permute) {
-            src = ggml_permute(ctx, src,
+            src = ggml_permute(ctx, src, permute_src[0], permute_src[1], permute_src[2], permute_src[3]);
             ggml_set_name(src, "src_permuted");
         }

-        ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
+        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
         ggml_set_name(dst, "dst");

+        if (_dst_use_permute) {
+            dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
+            ggml_set_name(dst, "dst_permuted");
+        }
+
         ggml_tensor * out = ggml_cpy(ctx, src, dst);
         ggml_set_name(out, "out");

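The net effect of these `test_cpy` hunks: the case can now exercise `ggml_cpy` into a destination that is itself a permuted (non-contiguous) view, not only from a permuted source. A minimal sketch of the graph shape being tested, assuming an already-initialized `ggml_context * ctx` (shapes illustrative):

```cpp
// Build a copy whose destination is a non-contiguous permuted view.
ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 2, 3, 4);
ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 3, 2, 4);
dst = ggml_permute(ctx, dst, 0, 2, 1, 3);    // dst now matches src's shape, with swapped strides
ggml_tensor * out = ggml_cpy(ctx, src, dst); // the backend must honour dst's layout
```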
@@ -1964,9 +1981,10 @@ struct test_mul_mat : public test_case {
     const std::array<int64_t, 2> bs;  // dims 3 and 4
     const std::array<int64_t, 2> nr;  // repeat in dims 3 and 4
     const std::array<int64_t, 4> per; // permutation of dimensions
+    const bool v; // whether a is a non-contiguous view

     std::string vars() override {
-        return
+        return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v);
     }

     double max_nmse_err() override {
@@ -1986,8 +2004,9 @@ struct test_mul_mat : public test_case {
             int64_t m = 32, int64_t n = 32, int64_t k = 32,
             std::array<int64_t, 2> bs = {10, 10},
             std::array<int64_t, 2> nr = {2, 2},
-            std::array<int64_t, 4> per = {0, 1, 2, 3}
-
+            std::array<int64_t, 4> per = {0, 1, 2, 3},
+            bool v = false)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
@@ -1997,6 +2016,7 @@ struct test_mul_mat : public test_case {
         const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
         if (npermuted > 0) {
             GGML_ASSERT(npermuted == 2);
+            GGML_ASSERT(!v); // not handled
             GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
             GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);

@@ -2020,7 +2040,13 @@ struct test_mul_mat : public test_case {
             ggml_set_name(a, "a_permuted");
             ggml_set_name(b, "b_permuted");
         } else {
-
+
+            if (v) {
+                a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]);
+                a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0);
+            } else {
+                a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+            }
             b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
             if (!ggml_is_quantized(type_a)) {
                 if (bs[1] == 1 && nr[1] == 1) {
@@ -2045,7 +2071,7 @@ struct test_mul_mat_id : public test_case {
     const ggml_type type_b;
     const int n_mats;
     const int n_used;
-    const bool b; //
+    const bool b; // broadcast b matrix
     const int64_t m;
     const int64_t n;
     const int64_t k;
@@ -2580,6 +2606,8 @@ struct test_rope : public test_case {
         } else {
             out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
         }
+
+        // TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp
     }
     ggml_set_name(out, "out");

@@ -2930,15 +2958,16 @@ struct test_upscale : public test_case {
     const std::array<int64_t, 4> ne;
     const int32_t scale_factor;
     const bool transpose;
+    const ggml_scale_mode mode;

     std::string vars() override {
-        return
+        return VARS_TO_STR5(type, ne, scale_factor, mode, transpose);
     }

     test_upscale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {512, 512, 3, 1},
-            int32_t scale_factor = 2, bool transpose = false)
-        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}
+            int32_t scale_factor = 2, ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST, bool transpose = false)
+        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose), mode(mode) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2949,7 +2978,7 @@ struct test_upscale : public test_case {
             ggml_set_name(a, "a_transposed");
         }

-        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
+        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor, mode);
         ggml_set_name(out, "out");

         return out;
@@ -2961,21 +2990,23 @@ struct test_upscale_ext : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const std::array<int64_t, 4> ne_tgt;
+    const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;

     std::string vars() override {
-        return
+        return VARS_TO_STR4(type, ne, ne_tgt, mode);
     }

     test_upscale_ext(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {2, 5, 7, 11},
-            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13}
-
+            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},
+            ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
+        : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");

-        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
+        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);
         ggml_set_name(out, "out");

         return out;
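These upscale hunks reflect the API change driving them: `ggml_upscale` and `ggml_upscale_ext` now take a `ggml_scale_mode` (nearest or bilinear) as a trailing argument, as the replaced call sites above show. A sketch of the updated calls, assuming an existing `ggml_context * ctx` and illustrative shapes:

```cpp
ggml_tensor * img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 512, 512, 3, 1);
// integer scale factor, with the interpolation chosen explicitly
ggml_tensor * up2x = ggml_upscale(ctx, img, 2, GGML_SCALE_MODE_BILINEAR);
// arbitrary per-dimension target sizes
ggml_tensor * upxy = ggml_upscale_ext(ctx, img, 1024, 1024, 3, 1, GGML_SCALE_MODE_NEAREST);
```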
@@ -3199,7 +3230,8 @@ struct test_leaky_relu : public test_case {

 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
-    const int64_t
+    const int64_t hsk; // K head size
+    const int64_t hsv; // V head size
     const int64_t nh; // num heads
     const int64_t nr; // repeat in Q, tests for grouped-query attention
     const int64_t kv; // kv size
@@ -3215,7 +3247,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;

     std::string vars() override {
-        return
+        return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }

     double max_nmse_err() override {
@@ -3225,17 +3257,18 @@ struct test_flash_attn_ext : public test_case {
     uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
         // Just counting matmul costs:
-        // Q*K^T is nb x
-        return 2 *
+        // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
+        return 2 * nh*nr * nb * (hsk + hsv) * kv;
     }

-    test_flash_attn_ext(int64_t
+    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
             bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
             ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        :
+        : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
-        const int64_t
+        const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
+        const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));

         auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
             int64_t ne[4] = {ne0, ne1, ne2, ne3};
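As a sanity check of the new FLOP formula: per head, Q·K^T is an nb×hsk by hsk×kv matmul and P·V is an nb×kv by kv×hsv matmul, giving 2·nb·kv·(hsk + hsv) FLOPs per head across nh·nr heads; with the constructor defaults (hsk = hsv = 128, nh = 32, nr = 1, kv = 96, nb = 8) that is 2 · 32 · 1 · 8 · (128 + 128) · 96 = 12,582,912, roughly 12.6 MFLOP per evaluation. When the two head sizes are equal the expression reduces to 4·nh·nr·nb·hs·kv.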
@@ -3250,13 +3283,13 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };

-        ggml_tensor * q = create_permuted(GGML_TYPE_F32,
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
         ggml_set_name(q, "q");

-        ggml_tensor * k = create_permuted(type_KV,
+        ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1);
         ggml_set_name(k, "k");

-        ggml_tensor * v = create_permuted(type_KV,
+        ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1);
         ggml_set_name(v, "v");

         ggml_tensor * m = nullptr;
@@ -3265,7 +3298,7 @@ struct test_flash_attn_ext : public test_case {
             ggml_set_name(m, "m");
         }

-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
         ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");

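Splitting the single head size into `hsk`/`hsv` lets the test cover attention where the K and V head dimensions differ, which is exactly the DeepSeek MLA shape (hsk = 576, hsv = 512) added to the eval loops below. With the new constructor, registering such a case looks like this (values illustrative, signature as introduced in this diff):

```cpp
test_cases.emplace_back(new test_flash_attn_ext(
    /*hsk*/ 576, /*hsv*/ 512, /*nh*/ 4, /*nr*/ 1, /*kv*/ 512, /*nb*/ 8,
    /*mask*/ true, /*max_bias*/ 0.0f, /*logit_softcap*/ 0.0f,
    GGML_PREC_F32, GGML_TYPE_F16));
```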
@@ -3995,14 +4028,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
     }

-
+    // same-type copy
+    for (ggml_type type : all_types) {
+        const auto nk = ggml_blck_size(type);
+
+        for (int k = 1; k < 4; ++k) {
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}));
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3}));
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3}));
+        }
+    }
+
+    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
         for (ggml_type type_dst : all_types) {
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
         }
     }
-    for (ggml_type
-        for (ggml_type
+    for (ggml_type type_src : all_types) {
+        for (ggml_type type_dst : {GGML_TYPE_F32}) {
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
         }
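Note the `{k*nk, 2, 3, 4}` shapes in the new same-type copy loop: `nk = ggml_blck_size(type)` is the number of elements per quantization block for the type, so scaling the row length by k keeps every row an integer number of blocks. That is what allows the loop to run over `all_types`, quantized formats included.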
@@ -4140,6 +4184,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            // test cases with large ne00/ne10 to cover stream-k fixup
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
         }
     }
     for (ggml_type type_a : other_types) {
@@ -4175,6 +4224,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83,  2,  64, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1,  67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
+
+    for (auto bs : {1,2,4,8}) {
+        for (auto nr : {1,4}) {
+            for (uint32_t m = 0; m < 2; ++m) {
+                for (uint32_t k = 0; k < 2; ++k) {
+                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
+                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
+                }
+            }
+        }
+    }

     // sycl backend will limit task global_range < MAX_INT
     // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)
@@ -4355,12 +4417,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
     }

+    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
+        test_cases.emplace_back(new test_upscale_ext(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
+    }
+
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_mean());
-    test_cases.emplace_back(new test_upscale());
-    test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
-    test_cases.emplace_back(new test_upscale_ext());
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
@@ -4370,27 +4435,33 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());

-    for (int
-        for (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                                            if (
+    for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
+        for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
+            if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
+            if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+            if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
+
+            for (bool mask : { true, false } ) {
+                for (float max_bias : { 0.0f, 8.0f }) {
+                    if (!mask && max_bias > 0.0f) continue;
+                    for (float logit_softcap : {0.0f, 10.0f}) {
+                        if (hsk != 128 && logit_softcap != 0.0f) continue;
+                        for (int nh : { 4, }) {
+                            for (int nr : { 1, 4, 16 }) {
+                                if (nr == 16 && hsk != 128) continue;
+                                for (int kv : { 512, 1024, }) {
+                                    if (nr != 1 && kv != 512) continue;
+                                    for (int nb : { 1, 3, 32, 35, }) {
+                                        for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                            if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                                 test_cases.emplace_back(new test_flash_attn_ext(
-
+                                                    hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                                // run fewer test cases permuted
+                                                if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                    test_cases.emplace_back(new test_flash_attn_ext(
+                                                        hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                }
                                             }
                                         }
                                     }
@@ -4444,6 +4515,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));

+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
+
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
@@ -4464,6 +4538,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }

+    for (int kv : { 4096, 8192, 16384, }) {
+        for (int hs : { 64, 128, }) {
+            for (int nr : { 1, 4, }) {
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            }
+        }
+    }
+
     return test_cases;
 }

package/src/llama.cpp/tests/test-chat-template.cpp:

@@ -19,6 +19,8 @@ static std::string normalize_newlines(const std::string & s) {
 #endif
 }

+#define U8C(x) (const char*)(u8##x)
+
 static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
     common_chat_msg msg;
     msg.role = role;
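Why `U8C` exists: since C++20, a `u8"..."` literal has type `const char8_t[]` rather than `const char[]`, so the UTF-8 template strings below would no longer convert to `std::string` or `const char*` fields without a cast. The macro papers over the difference:

```cpp
#define U8C(x) (const char*)(u8##x)

std::string a = u8"你好";     // OK in C++17, ill-formed in C++20 (char8_t)
std::string b = U8C("你好");  // fine in both: cast back to const char*
```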
@@ -35,6 +37,8 @@ int main(void) {
         {"assistant", " I am an assistant "},
         {"user", "Another question"},
     };
+
+    // std::string wrong = /* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}";
     struct TestCase {
         std::string name;
         std::string template_str;
@@ -177,24 +181,25 @@ int main(void) {
         },
         {
             /* .name= */ "ChatGLM4",
-            /* .template_str= */
+            /* .template_str= */ U8C("[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}"),
             /* .expected_output= */ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
-
-
-
-
-
-
-
-
+        // TODO @ngxson : GLMEdge produces poor result without `[gMASK]<sop>`, so we're temporarily using GLM4 template for it. We should fix this in the future.
+        // {
+        //     /* .name= */ "GLMEdge",
+        //     /* .template_str= */ "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}<|assistant|>",
+        //     /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+        //     /* .expected_output_jinja= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+        //     /* .bos_token= */ "",
+        //     /* .eos_token= */ "",
+        // },
         {
             /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
-            /* .template_str= */
-            /* .expected_output= */
+            /* .template_str= */ U8C("{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}"),
+            /* .expected_output= */ U8C("You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>"),
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
@@ -202,7 +207,7 @@ int main(void) {
         {
             /* .name= */ "DeepSeek-V2",
             /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
-            /* .expected_output= */
+            /* .expected_output= */ U8C("You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:"),
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "<|end▁of▁sentence|>",
@@ -256,7 +261,7 @@ int main(void) {
         },
         {
            /* .name= */ "Infinigence/Megrez-3B-Instruct",
-            /* .template_str= */
+            /* .template_str= */ U8C("{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}"),
             /* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
@@ -270,6 +275,22 @@ int main(void) {
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
+        {
+            /* .name= */ "yandex/YandexGPT-5-Lite-8B-instruct",
+            /* .template_str= */ "<s>{%- set names = {'assistant': ' Ассистент:', 'user': ' Пользователь:'} %}\n{%- set tools_prefix = 'Тебе доступны следующие функции:' %}\n{%- macro __render_tool(tool) %}\n    {%- set name = tool.function.name %}\n    {%- set description = tool.function.description|default('') %}\n    {%- set parameters = tool.function.parameters|tojson %}\n    {{- '\\n' }}function {{ '{' }}'name':'{{ name }}',\n    {%- if tool.function.description %}'description':'{{ description }}',{% endif %}\n'parameters':{{ parameters }}\n    {{- '}' }}\n{%- endmacro %}\n{%- macro __render_tools(tools) %}\n    {{- tools_prefix }}\n    {%- for tool in tools %}\n    {{- __render_tool(tool) }}\n    {%- endfor %}\n    {{- '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_tool_message(message) %}\n    {{- '\\n\\nРезультат вызова' }} {{ message.name }}: {{ message.content }} {{ '\\n\\n' }}\n{%- endmacro %}\n{%- if tools -%}\n    {{- __render_tools(tools) }}\n{%- endif -%}\n{%- macro __render_user_message(message) %}\n{{ names.user }} {{ message.content + '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_assistant_message(message) %}\n    {{- names.assistant }}\n    {%- set call = message['function_call'] %}\n    {%- if call %}\n    {{- '\\n[TOOL_CALL_START]' }}{{ call.name }}{{ '\\n' }}{{ call.arguments|tojson }}\n    {%- else %}\n    {{- ' ' + message.content + '\\n\\n' }}\n    {%- endif %}\n{%- endmacro %}\n{%- if not add_generation_prompt is defined %}\n{%- set add_generation_prompt = false %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'user' %}\n    {{- __render_user_message(message) }}\n    {%- endif %}\n    {%- if message.role == 'assistant' and not loop.last %}\n    {{- __render_assistant_message(message) }}\n    {%- endif %}\n    {%- if message.role == 'tool' %}\n    {{- __render_tool_message(message) }}\n    {%- endif %}\n    {%- if loop.last %}\n    {{- ' Ассистент:[SEP]' }}\n    {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */ "<s> Пользователь: Hello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент: I am an assistant \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .expected_output_jinja= */ "<s> Пользователь: You are a helpful assistant\nHello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент: I am an assistant \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "inclusionAI/Ling-lite",
+            /* .template_str */ "{% for message in messages %}{% set role = message['role'] | lower %}{% if role == 'user' %}{% set role = 'HUMAN' %}{% endif %}{% set role = role | upper %}{{ '<role>' + role + '</role>' + message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '<role>ASSISTANT</role>' }}{% endif %}",
+            /* .expected_output= */ "<role>SYSTEM</role>You are a helpful assistant<role>HUMAN</role>Hello<role>ASSISTANT</role>Hi there<role>HUMAN</role>Who are you<role>ASSISTANT</role> I am an assistant <role>HUMAN</role>Another question<role>ASSISTANT</role>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
package/src/llama.cpp/tests/test-chat.cpp:

@@ -11,8 +11,9 @@
 #include <string>

 #include "chat.h"
-
-#include "unicode.h"
+
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"

 using json = nlohmann::ordered_json;

@@ -569,6 +570,7 @@ static void test_template_output_parsers() {
     {
         // Not supported yet
         auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
     }
     {
@@ -665,6 +667,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
         std::vector<std::string> end_tokens{ "<|im_end|>" };

+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(
             COMMON_CHAT_FORMAT_HERMES_2_PRO,
@@ -793,6 +796,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                       common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
@@ -815,6 +819,7 @@ static void test_template_output_parsers() {
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);

         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
@@ -824,6 +829,8 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
+                      common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                       common_chat_templates_apply(tmpls.get(), inputs_tools).format);

@@ -851,6 +858,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
         std::vector<std::string> end_tokens{ "<|eot_id|>" };

+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
@@ -862,6 +870,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);

@@ -891,6 +900,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
