@fugood/llama.node 0.3.16 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +238 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +6 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -192
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1003 -13519
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +96 -22
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +2 -292
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  130. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +204 -280
  131. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  133. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  135. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  136. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +646 -114
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +17 -8
  142. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  143. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  144. package/src/llama.cpp/include/llama.h +30 -11
  145. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  147. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  149. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  150. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  151. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  152. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  153. package/src/llama.cpp/src/llama-arch.cpp +160 -17
  154. package/src/llama.cpp/src/llama-arch.h +16 -0
  155. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  156. package/src/llama.cpp/src/llama-chat.h +6 -2
  157. package/src/llama.cpp/src/llama-context.cpp +108 -92
  158. package/src/llama.cpp/src/llama-context.h +1 -2
  159. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  160. package/src/llama.cpp/src/llama-graph.h +26 -6
  161. package/src/llama.cpp/src/llama-hparams.h +13 -0
  162. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  163. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  164. package/src/llama.cpp/src/llama-memory.h +1 -1
  165. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  166. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  167. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  168. package/src/llama.cpp/src/llama-model.cpp +1760 -534
  169. package/src/llama.cpp/src/llama-model.h +13 -1
  170. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  171. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  172. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  173. package/src/llama.cpp/src/llama.cpp +1 -1
  174. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  175. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  176. package/src/llama.cpp/tests/test-backend-ops.cpp +82 -43
  177. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  178. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  179. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  180. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  181. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  182. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  183. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  184. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  185. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  186. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  188. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  189. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  190. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  191. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  192. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  193. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
@@ -271,6 +271,14 @@ static std::string var_to_str(ggml_op_pool pool) {
     }
 }
 
+static std::string var_to_str(ggml_scale_mode mode) {
+    switch (mode) {
+        case GGML_SCALE_MODE_NEAREST: return "nearest";
+        case GGML_SCALE_MODE_BILINEAR: return "bilinear";
+        default: return std::to_string(mode);
+    }
+}
+
 #define VAR_TO_STR(x) (#x "=" + var_to_str(x))
 
 #define VARS_TO_STR1(a) VAR_TO_STR(a)
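Note: with this overload in place, the VAR_TO_STR macro shown above can stringify a scale mode for test names. A minimal sketch (my own illustration, not a line from the package):

    ggml_scale_mode mode = GGML_SCALE_MODE_BILINEAR;
    std::string label = VAR_TO_STR(mode); // expands to "mode=" + var_to_str(mode) -> "mode=bilinear"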
@@ -2063,7 +2071,7 @@ struct test_mul_mat_id : public test_case {
     const ggml_type type_b;
     const int n_mats;
     const int n_used;
-    const bool b; // brodcast b matrix
+    const bool b; // broadcast b matrix
     const int64_t m;
     const int64_t n;
     const int64_t k;
@@ -2598,6 +2606,8 @@ struct test_rope : public test_case {
             } else {
                 out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
             }
+
+            // TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp
         }
         ggml_set_name(out, "out");
 
@@ -2948,15 +2958,16 @@ struct test_upscale : public test_case {
     const std::array<int64_t, 4> ne;
     const int32_t scale_factor;
     const bool transpose;
+    const ggml_scale_mode mode;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, scale_factor, transpose);
+        return VARS_TO_STR5(type, ne, scale_factor, mode, transpose);
     }
 
     test_upscale(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {512, 512, 3, 1},
-           int32_t scale_factor = 2, bool transpose = false)
-        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}
+           int32_t scale_factor = 2, ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST, bool transpose = false)
+        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose), mode(mode) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2967,7 +2978,7 @@ struct test_upscale : public test_case {
             ggml_set_name(a, "a_transposed");
         }
 
-        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
+        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor, mode);
         ggml_set_name(out, "out");
 
         return out;
@@ -2979,21 +2990,23 @@ struct test_upscale_ext : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
    const std::array<int64_t, 4> ne_tgt;
+    const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, ne_tgt);
+        return VARS_TO_STR4(type, ne, ne_tgt, mode);
     }
 
     test_upscale_ext(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {2, 5, 7, 11},
-           std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13})
-        : type(type), ne(ne), ne_tgt(ne_tgt) {}
+           std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},
+           ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
+        : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
+        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);
         ggml_set_name(out, "out");
 
         return out;
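For reference, a minimal sketch (my own example, not taken from the package) of calling the updated ggml upscale API directly, which now takes a ggml_scale_mode as shown in the hunks above; context creation and graph evaluation are omitted:

    #include "ggml.h"

    // builds a single upscale node with the new mode argument
    static ggml_tensor * build_bilinear_upscale(ggml_context * ctx) {
        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 3, 1);
        // previously: ggml_upscale(ctx, src, 2)
        ggml_tensor * out = ggml_upscale(ctx, src, 2, GGML_SCALE_MODE_BILINEAR);
        ggml_set_name(out, "upscaled_bilinear");
        return out;
    }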
@@ -3217,7 +3230,8 @@ struct test_leaky_relu : public test_case {
 
 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
-    const int64_t hs; // head size
+    const int64_t hsk; // K head size
+    const int64_t hsv; // V head size
     const int64_t nh; // num heads
     const int64_t nr; // repeat in Q, tests for grouped-query attention
     const int64_t kv; // kv size
@@ -3233,7 +3247,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+        return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3243,17 +3257,18 @@ struct test_flash_attn_ext : public test_case {
     uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
         // Just counting matmul costs:
-        // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
-        return 2 * 2 * nh*nr * nb * hs * kv;
+        // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
+        return 2 * nh*nr * nb * (hsk + hsv) * kv;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
+    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
            bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
            ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+        : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
+        const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
+        const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
 
         auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
             int64_t ne[4] = {ne0, ne1, ne2, ne3};
@@ -3268,13 +3283,13 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = create_permuted(type_KV, hs_padded, kv, nh, 1);
+        ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV, hs_padded, kv, nh, 1);
+        ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -3283,7 +3298,7 @@ struct test_flash_attn_ext : public test_case {
             ggml_set_name(m, "m");
         }
 
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
         ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");
 
@@ -4169,6 +4184,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            // test cases with large ne00/ne10 to cover stream-k fixup
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
         }
     }
     for (ggml_type type_a : other_types) {
@@ -4204,6 +4224,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 64, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
 
     for (auto bs : {1,2,4,8}) {
         for (auto nr : {1,4}) {
@@ -4395,12 +4417,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
     }
 
+    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
+        test_cases.emplace_back(new test_upscale_ext(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
+    }
+
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_mean());
-    test_cases.emplace_back(new test_upscale());
-    test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
-    test_cases.emplace_back(new test_upscale_ext());
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
@@ -4410,27 +4435,33 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
-    for (int hs : { 64, 80, 128, 256, }) {
-        for (bool mask : { true, false } ) {
-            for (float max_bias : { 0.0f, 8.0f }) {
-                if (!mask && max_bias > 0.0f) continue;
-                for (float logit_softcap : {0.0f, 10.0f}) {
-                    if (hs != 128 && logit_softcap != 0.0f) continue;
-                    for (int nh : { 4, }) {
-                        for (int nr : { 1, 4, 16 }) {
-                            if (nr == 16 && hs != 128) continue;
-                            for (int kv : { 512, 1024, }) {
-                                if (nr != 1 && kv != 512) continue;
-                                for (int nb : { 1, 3, 32, 35, }) {
-                                    for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
-                                        if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
-                                        for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                            test_cases.emplace_back(new test_flash_attn_ext(
-                                                hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
-                                            // run fewer test cases permuted
-                                            if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+    for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
+        for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
+            if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
+            if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+            if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
+
+            for (bool mask : { true, false } ) {
+                for (float max_bias : { 0.0f, 8.0f }) {
+                    if (!mask && max_bias > 0.0f) continue;
+                    for (float logit_softcap : {0.0f, 10.0f}) {
+                        if (hsk != 128 && logit_softcap != 0.0f) continue;
+                        for (int nh : { 4, }) {
+                            for (int nr : { 1, 4, 16 }) {
+                                if (nr == 16 && hsk != 128) continue;
+                                for (int kv : { 512, 1024, }) {
+                                    if (nr != 1 && kv != 512) continue;
+                                    for (int nb : { 1, 3, 32, 35, }) {
+                                        for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                            if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                                 test_cases.emplace_back(new test_flash_attn_ext(
-                                                    hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                    hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                                // run fewer test cases permuted
+                                                if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                    test_cases.emplace_back(new test_flash_attn_ext(
+                                                        hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                }
                                             }
                                         }
                                     }
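Note: with K and V head sizes now separate, the constructor above can express asymmetric-head configurations. A minimal sketch (my own example, not a line from the package) of adding the DeepSeek-MLA-shaped case gated in the loop above:

    // hsk=576, hsv=512 matches the "DeepSeek MLA" combination allowed by the filters above
    test_cases.emplace_back(new test_flash_attn_ext(
        /* hsk */ 576, /* hsv */ 512, /* nh */ 4, /* nr */ 1, /* kv */ 512, /* nb */ 8,
        /* mask */ true, /* max_bias */ 0.0f, /* logit_softcap */ 0.0f,
        /* prec */ GGML_PREC_F32, /* type_KV */ GGML_TYPE_F16));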
@@ -4507,6 +4538,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }
 
+    for (int kv : { 4096, 8192, 16384, }) {
+        for (int hs : { 64, 128, }) {
+            for (int nr : { 1, 4, }) {
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            }
+        }
+    }
+
     return test_cases;
 }
 
@@ -19,6 +19,8 @@ static std::string normalize_newlines(const std::string & s) {
 #endif
 }
 
+#define U8C(x) (const char*)(u8##x)
+
 static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
     common_chat_msg msg;
     msg.role = role;
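Note: the likely motivation for U8C (my reading, not stated in the diff) is that C++20 changes the type of u8"..." literals to const char8_t[], which no longer converts implicitly to const char* or std::string; the macro keeps the UTF-8 bytes while restoring the char type these tests use. A minimal sketch:

    std::string tmpl_ok = U8C("<用户>");   // compiles under both C++17 and C++20
    // std::string tmpl_bad = u8"<用户>";  // ill-formed in C++20 without a cast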
@@ -35,6 +37,8 @@ int main(void) {
         {"assistant", " I am an assistant "},
         {"user", "Another question"},
     };
+
+    // std::string wrong = /* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}";
     struct TestCase {
         std::string name;
         std::string template_str;
@@ -177,24 +181,25 @@ int main(void) {
         },
         {
             /* .name= */ "ChatGLM4",
-            /* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+            /* .template_str= */ U8C("[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}"),
             /* .expected_output= */ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
-        {
-            /* .name= */ "GLMEdge",
-            /* .template_str= */ "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}<|assistant|>",
-            /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
-            /* .expected_output_jinja= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
-            /* .bos_token= */ "",
-            /* .eos_token= */ "",
-        },
+        // TODO @ngxson : GLMEdge produces poor result without `[gMASK]<sop>`, so we're temporarily using GLM4 template for it. We should fix this in the future.
+        // {
+        //     /* .name= */ "GLMEdge",
+        //     /* .template_str= */ "{% for item in messages %}{% if item['role'] == 'system' %}<|system|>\n{{ item['content'] }}{% elif item['role'] == 'user' %}<|user|>\n{{ item['content'] }}{% elif item['role'] == 'assistant' %}<|assistant|>\n{{ item['content'] }}{% endif %}{% endfor %}<|assistant|>",
+        //     /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+        //     /* .expected_output_jinja= */ "<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+        //     /* .bos_token= */ "",
+        //     /* .eos_token= */ "",
+        // },
         {
             /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
-            /* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
-            /* .expected_output= */ u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
+            /* .template_str= */ U8C("{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}"),
+            /* .expected_output= */ U8C("You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>"),
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
@@ -202,7 +207,7 @@ int main(void) {
         {
             /* .name= */ "DeepSeek-V2",
             /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
-            /* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
+            /* .expected_output= */ U8C("You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:"),
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "<|end▁of▁sentence|>",
@@ -256,7 +261,7 @@ int main(void) {
         },
         {
             /* .name= */ "Infinigence/Megrez-3B-Instruct",
-            /* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
+            /* .template_str= */ U8C("{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}"),
             /* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
@@ -270,6 +275,22 @@ int main(void) {
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
+        {
+            /* .name= */ "yandex/YandexGPT-5-Lite-8B-instruct",
+            /* .template_str= */ "<s>{%- set names = {'assistant': ' Ассистент:', 'user': ' Пользователь:'} %}\n{%- set tools_prefix = 'Тебе доступны следующие функции:' %}\n{%- macro __render_tool(tool) %}\n {%- set name = tool.function.name %}\n {%- set description = tool.function.description|default('') %}\n {%- set parameters = tool.function.parameters|tojson %}\n {{- '\\n' }}function {{ '{' }}'name':'{{ name }}',\n {%- if tool.function.description %}'description':'{{ description }}',{% endif %}\n'parameters':{{ parameters }}\n {{- '}' }}\n{%- endmacro %}\n{%- macro __render_tools(tools) %}\n {{- tools_prefix }}\n {%- for tool in tools %}\n {{- __render_tool(tool) }}\n {%- endfor %}\n {{- '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_tool_message(message) %}\n {{- '\\n\\nРезультат вызова' }} {{ message.name }}: {{ message.content }} {{ '\\n\\n' }}\n{%- endmacro %}\n{%- if tools -%}\n {{- __render_tools(tools) }}\n{%- endif -%}\n{%- macro __render_user_message(message) %}\n{{ names.user }} {{ message.content + '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_assistant_message(message) %}\n {{- names.assistant }}\n {%- set call = message['function_call'] %}\n {%- if call %}\n {{- '\\n[TOOL_CALL_START]' }}{{ call.name }}{{ '\\n' }}{{ call.arguments|tojson }}\n {%- else %}\n {{- ' ' + message.content + '\\n\\n' }}\n {%- endif %}\n{%- endmacro %}\n{%- if not add_generation_prompt is defined %}\n{%- set add_generation_prompt = false %}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'user' %}\n {{- __render_user_message(message) }}\n {%- endif %}\n {%- if message.role == 'assistant' and not loop.last %}\n {{- __render_assistant_message(message) }}\n {%- endif %}\n {%- if message.role == 'tool' %}\n {{- __render_tool_message(message) }}\n {%- endif %}\n {%- if loop.last %}\n {{- ' Ассистент:[SEP]' }}\n {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */ "<s> Пользователь: Hello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент: I am an assistant \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .expected_output_jinja= */ "<s> Пользователь: You are a helpful assistant\nHello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент: I am an assistant \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "inclusionAI/Ling-lite",
+            /* .template_str */ "{% for message in messages %}{% set role = message['role'] | lower %}{% if role == 'user' %}{% set role = 'HUMAN' %}{% endif %}{% set role = role | upper %}{{ '<role>' + role + '</role>' + message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '<role>ASSISTANT</role>' }}{% endif %}",
+            /* .expected_output= */ "<role>SYSTEM</role>You are a helpful assistant<role>HUMAN</role>Hello<role>ASSISTANT</role>Hi there<role>HUMAN</role>Who are you<role>ASSISTANT</role> I am an assistant <role>HUMAN</role>Another question<role>ASSISTANT</role>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
@@ -11,8 +11,9 @@
 #include <string>
 
 #include "chat.h"
-#include "llama-grammar.h"
-#include "unicode.h"
+
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
 
 using json = nlohmann::ordered_json;
 
@@ -569,6 +570,7 @@ static void test_template_output_parsers() {
     {
         // Not supported yet
         auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
     }
     {
@@ -665,6 +667,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
         std::vector<std::string> end_tokens{ "<|im_end|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(
             COMMON_CHAT_FORMAT_HERMES_2_PRO,
@@ -793,6 +796,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
                       common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
@@ -815,6 +819,7 @@ static void test_template_output_parsers() {
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
@@ -824,6 +829,8 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
+                      common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                       common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
@@ -851,6 +858,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
         std::vector<std::string> end_tokens{ "<|eot_id|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
@@ -862,6 +870,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
@@ -891,6 +900,7 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
 
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
@@ -1,5 +1,5 @@
-#include "unicode.h"
-#include "llama-grammar.h"
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
 
 #include <cstdio>
 #include <cstdlib>
@@ -2,10 +2,11 @@
 #undef NDEBUG
 #endif
 
-#include "unicode.h"
-#include "llama-grammar.h"
 #include "json-schema-to-grammar.h"
 
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
+
 #include <cassert>
 #include <string>
 #include <vector>
@@ -2,7 +2,6 @@
 # undef NDEBUG
 #endif
 
-#include "unicode.h"
 #include "sampling.h"
 
 #include <cassert>
@@ -84,7 +83,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
 
         fprintf(stderr,
                 "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following "
-                "command: ./llama-gbnf-validator test-grammar-integration.grammar.gbnf "
+                "command: ./test-gbnf-validator test-grammar-integration.grammar.gbnf "
                 "test-grammar-integration.string.txt\n\n");
     } else {
         fprintf(stdout, "✅︎\n");
@@ -1086,6 +1085,65 @@ static void test_json_schema() {
     });
 }
 
+static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
+    auto n_vocab = tok_arr.size;
+
+    tok_arr.selected = -1;
+    tok_arr.sorted = false;
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        tok_arr.data[token_id].id = token_id;
+        tok_arr.data[token_id].logit = 0.0f;
+    }
+
+    tok_arr.data[selected].logit = 100.0f;
+}
+
+static void test_sampler_chain(void) {
+    auto sparams = llama_sampler_chain_default_params();
+    sparams.no_perf = false;
+    llama_sampler * sampler = llama_sampler_chain_init(sparams);
+
+    const auto grammar_data = R"(%llguidance {}
+start: /[A-Z ]*/)";
+
+    llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
+    llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
+
+    auto input = "ALL YOUR BASE ARE BELONG TO US";
+    auto tokens = common_tokenize(vocab, input, false, false);
+
+    auto n_vocab = llama_vocab_n_tokens(vocab);
+
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
+    }
+    auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
+
+    for (const auto token : tokens) {
+        one_hot(tok_arr, token);
+
+        fprintf(stderr, "applying token: %d\n", token);
+        llama_sampler_apply(sampler, &tok_arr);
+
+        auto idx = tok_arr.selected;
+        fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
+        assert(cur[tok_arr.selected].id == token);
+        llama_sampler_accept(sampler, token);
+    }
+
+    auto tok_eos = llama_vocab_eot(vocab);
+    if (tok_eos == LLAMA_TOKEN_NULL) {
+        tok_eos = llama_vocab_eos(vocab);
+    }
+
+    one_hot(tok_arr, tok_eos);
+
+    llama_sampler_apply(sampler, &tok_arr);
+    assert(cur[tok_arr.selected].id == tok_eos);
+}
+
 int main(int argc, const char ** argv) {
     fprintf(stdout, "Running llguidance integration tests...\n");
 
@@ -1135,6 +1193,9 @@ int main(int argc, const char ** argv) {
     test_special_chars();
     test_quantifiers();
     test_json_schema();
+
+    test_sampler_chain();
+
     fprintf(stdout, "All tests passed.\n");
     return 0;
 }
@@ -3,7 +3,9 @@
 #endif
 
 #include "llama.h"
-#include "llama-grammar.h"
+
+// TODO: shold not include libllama sources
+#include "../src/llama-grammar.h"
 
 #include <cassert>
 
@@ -4,7 +4,7 @@
 
 #include "json-schema-to-grammar.h"
 
-#include "llama-grammar.h"
+#include "../src/llama-grammar.h"
 
 #include <cassert>
 #include <fstream>
@@ -597,6 +597,22 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
         )"""
     });
 
+    test({
+        SUCCESS,
+        "maxItems 0",
+        R"""({
+            "items": {
+                "type": "boolean"
+            },
+            "maxItems": 0
+        })""",
+        R"""(
+            boolean ::= ("true" | "false") space
+            root ::= "[" space "]" space
+            space ::= | " " | "\n"{1,2} [ \t]{0,20}
+        )"""
+    });
+
     test({
         SUCCESS,
         "maxItems 1",
@@ -3,7 +3,8 @@
 #endif
 
 #include "llama.h"
-#include "llama-grammar.h"
+
+#include "../src/llama-grammar.h"
 
 #include <cassert>
 #include <stdexcept>
@@ -1,8 +1,10 @@
 #include "ggml.h"
+#include "ggml-cpu.h"
 #include "llama.h"
-#include "llama-model.h"
 #include "common.h"
 
+#include "../src/llama-model.h"
+
 #include <algorithm>
 #include <cassert>
 #include <cinttypes>
@@ -1,8 +1,9 @@
 #include "llama.h"
 #include "common.h"
-#include "unicode.h"
 #include "console.h"
 
+#include "../src/unicode.h"
+
 #include <cassert>
 #include <codecvt>
 #include <cstdio>