@fugood/llama.node 0.3.15 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +243 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +14 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  136. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  138. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
  143. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  144. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  145. package/src/llama.cpp/include/llama.h +30 -11
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  147. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  149. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  150. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  151. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  152. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  153. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  154. package/src/llama.cpp/src/llama-arch.cpp +161 -17
  155. package/src/llama.cpp/src/llama-arch.h +16 -0
  156. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  157. package/src/llama.cpp/src/llama-chat.h +6 -2
  158. package/src/llama.cpp/src/llama-context.cpp +108 -92
  159. package/src/llama.cpp/src/llama-context.h +1 -2
  160. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  161. package/src/llama.cpp/src/llama-graph.h +26 -6
  162. package/src/llama.cpp/src/llama-hparams.h +13 -0
  163. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  164. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  165. package/src/llama.cpp/src/llama-memory.h +1 -1
  166. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  167. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  168. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  169. package/src/llama.cpp/src/llama-model.cpp +1544 -291
  170. package/src/llama.cpp/src/llama-model.h +13 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  172. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  173. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  174. package/src/llama.cpp/src/llama.cpp +1 -1
  175. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  176. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  177. package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
  178. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  179. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  180. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  181. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  182. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  183. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  184. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  185. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  186. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  188. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  189. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  190. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  191. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  192. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  193. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  203. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
@@ -271,6 +271,14 @@ static std::string var_to_str(ggml_op_pool pool) {
     }
 }
 
+static std::string var_to_str(ggml_scale_mode mode) {
+    switch (mode) {
+        case GGML_SCALE_MODE_NEAREST:  return "nearest";
+        case GGML_SCALE_MODE_BILINEAR: return "bilinear";
+        default:                       return std::to_string(mode);
+    }
+}
+
 #define VAR_TO_STR(x) (#x "=" + var_to_str(x))
 
 #define VARS_TO_STR1(a) VAR_TO_STR(a)
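Note: this overload is what makes the upscale mode readable in failing-test labels. For context (a sketch, not part of the diff), VAR_TO_STR stringifies both the name and the value, so a test_upscale variation is reported roughly as:

    // type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear,transpose=0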
@@ -1463,11 +1471,13 @@ struct test_cpy : public test_case {
     const ggml_type type_src;
     const ggml_type type_dst;
     const std::array<int64_t, 4> ne;
-    const std::array<int64_t, 4> permute;
+    const std::array<int64_t, 4> permute_src;
+    const std::array<int64_t, 4> permute_dst;
     bool _src_use_permute;
+    bool _dst_use_permute;
 
     std::string vars() override {
-        return VARS_TO_STR4(type_src, type_dst, ne, permute);
+        return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
     }
 
     double max_nmse_err() override {
@@ -1480,9 +1490,11 @@ struct test_cpy : public test_case {
 
     test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            std::array<int64_t, 4> permute = {0, 0, 0, 0})
-        : type_src(type_src), type_dst(type_dst), ne(ne), permute(permute),
-          _src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
+            std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
+            std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
+        : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
+          _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
+          _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -1490,13 +1502,18 @@ struct test_cpy : public test_case {
         ggml_set_name(src, "src");
 
         if (_src_use_permute) {
-            src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
+            src = ggml_permute(ctx, src, permute_src[0], permute_src[1], permute_src[2], permute_src[3]);
             ggml_set_name(src, "src_permuted");
         }
 
-        ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
+        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
         ggml_set_name(dst, "dst");
 
+        if (_dst_use_permute) {
+            dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
+            ggml_set_name(dst, "dst_permuted");
+        }
+
         ggml_tensor * out = ggml_cpy(ctx, src, dst);
         ggml_set_name(out, "out");
 
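Note: test_cpy can now route the copy through a permuted (non-contiguous) destination as well as a permuted source. For readers unfamiliar with the pattern this exercises, here is a minimal standalone sketch (sizes and names are illustrative, not from the diff); ggml_cpy only requires that source and destination have the same element count:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // non-contiguous source: a permuted view of a 32 x 2 x 3 x 4 tensor
        struct ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 2, 3, 4);
        src = ggml_permute(ctx, src, 0, 3, 1, 2);

        // non-contiguous destination: another permuted view with the same 768 elements
        struct ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 3, 4, 2);
        dst = ggml_permute(ctx, dst, 0, 2, 1, 3);

        struct ggml_tensor * out = ggml_cpy(ctx, src, dst);
        (void) out; // add to a ggml_cgraph and run it on a backend to execute

        ggml_free(ctx);
        return 0;
    }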
@@ -1964,9 +1981,10 @@ struct test_mul_mat : public test_case {
     const std::array<int64_t, 2> bs;  // dims 3 and 4
     const std::array<int64_t, 2> nr;  // repeat in dims 3 and 4
     const std::array<int64_t, 4> per; // permutation of dimensions
+    const bool v; // whether a is a non-contiguous view
 
     std::string vars() override {
-        return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
+        return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v);
     }
 
     double max_nmse_err() override {
@@ -1986,8 +2004,9 @@ struct test_mul_mat : public test_case {
             int64_t m = 32, int64_t n = 32, int64_t k = 32,
             std::array<int64_t, 2> bs = {10, 10},
             std::array<int64_t, 2> nr = {2, 2},
-            std::array<int64_t, 4> per = {0, 1, 2, 3})
-        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
+            std::array<int64_t, 4> per = {0, 1, 2, 3},
+            bool v = false)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
@@ -1997,6 +2016,7 @@ struct test_mul_mat : public test_case {
         const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
         if (npermuted > 0) {
             GGML_ASSERT(npermuted == 2);
+            GGML_ASSERT(!v); // not handled
             GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
             GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
 
@@ -2020,7 +2040,13 @@ struct test_mul_mat : public test_case {
             ggml_set_name(a, "a_permuted");
             ggml_set_name(b, "b_permuted");
         } else {
-            a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+
+            if (v) {
+                a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]);
+                a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0);
+            } else {
+                a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+            }
             b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
             if (!ggml_is_quantized(type_a)) {
                 if (bs[1] == 1 && nr[1] == 1) {
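Note: when v is set, A is allocated with rows of 2*k elements and ggml_view_4d then takes only the first k elements of each row while keeping the parent's byte strides (a->nb[1..3]). Each logical row is therefore followed by a gap of unused memory, which makes A non-contiguous and forces backends through their strided-view code paths; the GGML_ASSERT(!v) above excludes the separately handled permuted case.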
@@ -2045,7 +2071,7 @@ struct test_mul_mat_id : public test_case {
     const ggml_type type_b;
     const int n_mats;
     const int n_used;
-    const bool b; // brodcast b matrix
+    const bool b; // broadcast b matrix
     const int64_t m;
     const int64_t n;
     const int64_t k;
@@ -2580,6 +2606,8 @@ struct test_rope : public test_case {
         } else {
             out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
         }
+
+        // TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp
     }
     ggml_set_name(out, "out");
 
@@ -2930,15 +2958,16 @@ struct test_upscale : public test_case {
     const std::array<int64_t, 4> ne;
     const int32_t scale_factor;
     const bool transpose;
+    const ggml_scale_mode mode;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, scale_factor, transpose);
+        return VARS_TO_STR5(type, ne, scale_factor, mode, transpose);
     }
 
     test_upscale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {512, 512, 3, 1},
-            int32_t scale_factor = 2, bool transpose = false)
-        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}
+            int32_t scale_factor = 2, ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST, bool transpose = false)
+        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose), mode(mode) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2949,7 +2978,7 @@ struct test_upscale : public test_case {
             ggml_set_name(a, "a_transposed");
         }
 
-        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
+        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor, mode);
         ggml_set_name(out, "out");
 
         return out;
@@ -2961,21 +2990,23 @@ struct test_upscale_ext : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const std::array<int64_t, 4> ne_tgt;
+    const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, ne_tgt);
+        return VARS_TO_STR4(type, ne, ne_tgt, mode);
     }
 
     test_upscale_ext(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {2, 5, 7, 11},
-            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13})
-        : type(type), ne(ne), ne_tgt(ne_tgt) {}
+            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},
+            ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
+        : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
+        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);
         ggml_set_name(out, "out");
 
         return out;
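Note: both upscale entry points now take the interpolation mode explicitly. A minimal sketch of the calls with this tree's signatures (the tensor and sizes are illustrative; ctx is a ggml_context as in the sketch further up):

    // image laid out as W x H x C x N, matching the tests above
    struct ggml_tensor * img = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 3, 1);

    // uniform integer factor: 64x64 -> 128x128, bilinear filtering
    struct ggml_tensor * up2 = ggml_upscale(ctx, img, 2, GGML_SCALE_MODE_BILINEAR);

    // explicit target extents per dimension
    struct ggml_tensor * upe = ggml_upscale_ext(ctx, img, 96, 80, 3, 1, GGML_SCALE_MODE_BILINEAR);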
@@ -3199,7 +3230,8 @@ struct test_leaky_relu : public test_case {
 
 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
-    const int64_t hs; // head size
+    const int64_t hsk; // K head size
+    const int64_t hsv; // V head size
     const int64_t nh; // num heads
     const int64_t nr; // repeat in Q, tests for grouped-query attention
     const int64_t kv; // kv size
@@ -3215,7 +3247,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+        return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3225,17 +3257,18 @@ struct test_flash_attn_ext : public test_case {
     uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
         // Just counting matmul costs:
-        // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
-        return 2 * 2 * nh*nr * nb * hs * kv;
+        // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
+        return 2 * nh*nr * nb * (hsk + hsv) * kv;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
+    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
             bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
             ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+        : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
+        const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
+        const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
 
         auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
             int64_t ne[4] = {ne0, ne1, ne2, ne3};
@@ -3250,13 +3283,13 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded,  nb, nh*nr, 1);
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = create_permuted(type_KV, hs_padded,  kv, nh, 1);
+        ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV, hs_padded,  kv, nh, 1);
+        ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -3265,7 +3298,7 @@ struct test_flash_attn_ext : public test_case {
             ggml_set_name(m, "m");
         }
 
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
         ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");
 
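Note: splitting hs into hsk/hsv is what allows flash attention over DeepSeek-MLA-shaped heads, where K is 576 wide and V is 512 wide. A minimal sketch of such a call (illustrative sizes, mask omitted; assumes a ggml_context * ctx and <math.h> for sqrtf):

    const int64_t hsk = 576, hsv = 512;          // K/V head sizes (DeepSeek MLA)
    const int64_t nh  = 8,   kv  = 512, nb = 4;  // heads, KV length, batch

    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hsk, nb, nh, 1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hsk, kv, nh, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hsv, kv, nh, 1);

    // the softmax scale is derived from the K head size; the output row width
    // follows V, i.e. out->ne[0] == hsv
    struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, /*mask=*/NULL,
            1.0f/sqrtf((float) hsk), /*max_bias=*/0.0f, /*logit_softcap=*/0.0f);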
@@ -3995,14 +4028,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
     }
 
-    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+    // same-type copy
+    for (ggml_type type : all_types) {
+        const auto nk = ggml_blck_size(type);
+
+        for (int k = 1; k < 4; ++k) {
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}));
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3}));
+            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3}));
+        }
+    }
+
+    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
         for (ggml_type type_dst : all_types) {
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
         }
     }
-    for (ggml_type type_dst : {GGML_TYPE_F32}) {
-        for (ggml_type type_src : all_types) {
+    for (ggml_type type_src : all_types) {
+        for (ggml_type type_dst : {GGML_TYPE_F32}) {
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
         }
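Note: the new same-type cases size ne[0] as k*nk because quantized types store whole blocks per row: ggml_blck_size(type) is the element count of one quantization block, and a row must be a multiple of it. For example (values from ggml's type traits):

    ggml_blck_size(GGML_TYPE_F32);  // 1: scalar type, any row length is valid
    ggml_blck_size(GGML_TYPE_Q4_0); // 32: rows must hold a whole number of 32-element blocks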
@@ -4140,6 +4184,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            // test cases with large ne00/ne10 to cover stream-k fixup
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
         }
     }
     for (ggml_type type_a : other_types) {
@@ -4175,6 +4224,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83,  2,  64, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1,  67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
+
+    for (auto bs : {1,2,4,8}) {
+        for (auto nr : {1,4}) {
+            for (uint32_t m = 0; m < 2; ++m) {
+                for (uint32_t k = 0; k < 2; ++k) {
+                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
+                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
+                }
+            }
+        }
+    }
 
     // sycl backend will limit task global_range < MAX_INT
     // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)
@@ -4355,12 +4417,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
     }
 
+    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
+        test_cases.emplace_back(new test_upscale_ext(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
+    }
+
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_mean());
-    test_cases.emplace_back(new test_upscale());
-    test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
-    test_cases.emplace_back(new test_upscale_ext());
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
@@ -4370,27 +4435,33 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
-    for (int hs : { 64, 80, 128, 256, }) {
-        for (bool mask : { true, false } ) {
-            for (float max_bias : { 0.0f, 8.0f }) {
-                if (!mask && max_bias > 0.0f) continue;
-                for (float logit_softcap : {0.0f, 10.0f}) {
-                    if (hs != 128 && logit_softcap != 0.0f) continue;
-                    for (int nh : { 4, }) {
-                        for (int nr : { 1, 4, 16 }) {
-                            if (nr == 16 && hs != 128) continue;
-                            for (int kv : { 512, 1024, }) {
-                                if (nr != 1 && kv != 512) continue;
-                                for (int nb : { 1, 3, 32, 35, }) {
-                                    for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
-                                        if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
-                                        for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                            test_cases.emplace_back(new test_flash_attn_ext(
-                                                hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
-                                            // run fewer test cases permuted
-                                            if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+    for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
+        for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
+            if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
+            if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+            if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
+
+            for (bool mask : { true, false } ) {
+                for (float max_bias : { 0.0f, 8.0f }) {
+                    if (!mask && max_bias > 0.0f) continue;
+                    for (float logit_softcap : {0.0f, 10.0f}) {
+                        if (hsk != 128 && logit_softcap != 0.0f) continue;
+                        for (int nh : { 4, }) {
+                            for (int nr : { 1, 4, 16 }) {
+                                if (nr == 16 && hsk != 128) continue;
+                                for (int kv : { 512, 1024, }) {
+                                    if (nr != 1 && kv != 512) continue;
+                                    for (int nb : { 1, 3, 32, 35, }) {
+                                        for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                            if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                                 test_cases.emplace_back(new test_flash_attn_ext(
-                                                    hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                    hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                                // run fewer test cases permuted
+                                                if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                    test_cases.emplace_back(new test_flash_attn_ext(
+                                                        hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                }
                                             }
                                         }
                                     }
@@ -4444,6 +4515,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
 
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
+
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
@@ -4464,6 +4538,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }
 
+    for (int kv : { 4096, 8192, 16384, }) {
+        for (int hs : { 64, 128, }) {
+            for (int nr : { 1, 4, }) {
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            }
+        }
+    }
+
     return test_cases;
 }
 
@@ -19,6 +19,8 @@ static std::string normalize_newlines(const std::string & s) {
 #endif
 }
 
+#define U8C(x) (const char*)(u8##x)
+
 static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
     common_chat_msg msg;
     msg.role = role;
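Background on why U8C is introduced (not stated in the diff itself): in C++20, a u8"..." literal has type const char8_t[N] and no longer converts implicitly to const char *, so the u8-prefixed template strings below would stop compiling under -std=c++20. The macro pastes the u8 prefix onto the literal and casts the result back:

    const char * a = u8"你好";     // OK in C++17; ill-formed in C++20 (char8_t)
    const char * b = U8C("你好");  // expands to (const char*)(u8"你好"), valid in both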
@@ -35,6 +37,8 @@ int main(void) {
         {"assistant", " I am an assistant "},
         {"user", "Another question"},
     };
+
+    // std::string wrong = /* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}";
     struct TestCase {
         std::string name;
         std::string template_str;
@@ -177,24 +181,25 @@ int main(void) {
         },
         {
             /* .name= */ "ChatGLM4",
-            /* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+            /* .template_str= */ U8C("[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}"),
             /* .expected_output= */ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
-        {
-            /* .name= */ "GLMEdge",
-            /* .template_str= */ "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}<|assistant|>",
-            /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
-            /* .expected_output_jinja= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
-            /* .bos_token= */ "",
-            /* .eos_token= */ "",
-        },
+        // TODO @ngxson : GLMEdge produces poor result without `[gMASK]<sop>`, so we're temporarily using GLM4 template for it. We should fix this in the future.
+        // {
+        //     /* .name= */ "GLMEdge",
+        //     /* .template_str= */ "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}<|assistant|>",
+        //     /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+        //     /* .expected_output_jinja= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+        //     /* .bos_token= */ "",
+        //     /* .eos_token= */ "",
+        // },
         {
             /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
-            /* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
-            /* .expected_output= */ u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
+            /* .template_str= */ U8C("{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}"),
+            /* .expected_output= */ U8C("You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>"),
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
@@ -202,7 +207,7 @@ int main(void) {
         {
             /* .name= */ "DeepSeek-V2",
             /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
-            /* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
+            /* .expected_output= */ U8C("You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:"),
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "<|end▁of▁sentence|>",
@@ -256,7 +261,7 @@ int main(void) {
         },
         {
             /* .name= */ "Infinigence/Megrez-3B-Instruct",
-            /* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
+            /* .template_str= */ U8C("{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}"),
             /* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
@@ -270,6 +275,22 @@ int main(void) {
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
+        {
+            /* .name= */ "yandex/YandexGPT-5-Lite-8B-instruct",
+            /* .template_str= */ "<s>{%- set names = {'assistant': ' Ассистент:', 'user': ' Пользователь:'} %}\n{%- set tools_prefix = 'Тебе доступны следующие функции:' %}\n{%- macro __render_tool(tool) %}\n    {%- set name = tool.function.name %}\n    {%- set description = tool.function.description|default('') %}\n    {%- set parameters = tool.function.parameters|tojson %}\n    {{- '\\n' }}function {{ '{' }}'name':'{{ name }}',\n    {%- if tool.function.description %}'description':'{{ description }}',{% endif %}\n'parameters':{{ parameters }}\n    {{- '}' }}\n{%- endmacro %}\n{%- macro __render_tools(tools) %}\n    {{- tools_prefix }}\n    {%- for tool in tools %}\n        {{- __render_tool(tool) }}\n    {%- endfor %}\n    {{- '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_tool_message(message) %}\n    {{- '\\n\\nРезультат вызова' }} {{ message.name }}: {{ message.content }} {{ '\\n\\n' }}\n{%- endmacro %}\n{%- if tools -%}\n    {{- __render_tools(tools) }}\n{%- endif -%}\n{%- macro __render_user_message(message) %}\n{{ names.user }} {{ message.content + '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_assistant_message(message) %}\n    {{- names.assistant }}\n    {%- set call = message['function_call'] %}\n    {%- if call %}\n        {{- '\\n[TOOL_CALL_START]' }}{{ call.name }}{{ '\\n' }}{{ call.arguments|tojson }}\n    {%- else %}\n        {{- ' ' + message.content + '\\n\\n' }}\n    {%- endif %}\n{%- endmacro %}\n{%- if not add_generation_prompt is defined %}\n{%- set add_generation_prompt = false %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'user' %}\n        {{- __render_user_message(message) }}\n    {%- endif %}\n    {%- if message.role == 'assistant' and not loop.last %}\n        {{- __render_assistant_message(message) }}\n    {%- endif %}\n    {%- if message.role == 'tool' %}\n        {{- __render_tool_message(message) }}\n    {%- endif %}\n    {%- if loop.last %}\n        {{- ' Ассистент:[SEP]' }}\n    {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */ "<s> Пользователь: Hello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент: I am an assistant \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .expected_output_jinja= */ "<s> Пользователь: You are a helpful assistant\nHello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент: I am an assistant \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "inclusionAI/Ling-lite",
+            /* .template_str */ "{% for message in messages %}{% set role = message['role'] | lower %}{% if role == 'user' %}{% set role = 'HUMAN' %}{% endif %}{% set role = role | upper %}{{ '<role>' + role + '</role>' + message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '<role>ASSISTANT</role>' }}{% endif %}",
+            /* .expected_output= */ "<role>SYSTEM</role>You are a helpful assistant<role>HUMAN</role>Hello<role>ASSISTANT</role>Hi there<role>HUMAN</role>Who are you<role>ASSISTANT</role> I am an assistant <role>HUMAN</role>Another question<role>ASSISTANT</role>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
@@ -11,8 +11,9 @@
 #include <string>
 
 #include "chat.h"
-#include "llama-grammar.h"
-#include "unicode.h"
+
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
 
 using json = nlohmann::ordered_json;
 
@@ -569,6 +570,7 @@ static void test_template_output_parsers() {
     {
         // Not supported yet
        auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
     }
     {
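Note: the same assertion recurs through the hunks below: for each template suite, applying the template with inputs_no_tools must now resolve to COMMON_CHAT_FORMAT_CONTENT_ONLY, while tool-enabled inputs keep their specialized formats; the two DeepSeek-R1 suites are the exception, reporting COMMON_CHAT_FORMAT_DEEPSEEK_R1 with or without tools.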
@@ -665,6 +667,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
         std::vector<std::string> end_tokens{ "<|im_end|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(
             COMMON_CHAT_FORMAT_HERMES_2_PRO,
@@ -793,6 +796,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                       common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
@@ -815,6 +819,7 @@ static void test_template_output_parsers() {
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
@@ -824,6 +829,8 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
+                      common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                       common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
@@ -851,6 +858,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
         std::vector<std::string> end_tokens{ "<|eot_id|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
@@ -862,6 +870,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
@@ -891,6 +900,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
@@ -1,5 +1,5 @@
-#include "unicode.h"
-#include "llama-grammar.h"
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
 
 #include <cstdio>
 #include <cstdlib>
@@ -2,10 +2,11 @@
 #undef NDEBUG
 #endif
 
-#include "unicode.h"
-#include "llama-grammar.h"
 #include "json-schema-to-grammar.h"
 
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
+
 #include <cassert>
 #include <string>
 #include <vector>