@fugood/llama.node 0.3.3 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +1 -1
  21. package/src/LlamaContext.cpp +81 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
package/src/llama.cpp/tests/test-backend-ops.cpp

@@ -16,7 +16,6 @@
 
 
 #include <ggml.h>
-#include <ggml-cpu.h>
 #include <ggml-alloc.h>
 #include <ggml-backend.h>
 
@@ -26,7 +25,6 @@
 #include <cstdint>
 #include <cstring>
 #include <cinttypes>
-#include <functional>
 #include <memory>
 #include <random>
 #include <stdio.h>
@@ -639,19 +637,20 @@ struct test_case {
 
         // determine number of runs
         int n_runs;
+        bool is_cpu = ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_CPU;
         if (op_flops(out) > 0) {
             // based on flops
             const uint64_t GFLOP = 1000 * 1000 * 1000;
             const uint64_t target_flops_cpu = 8ULL * GFLOP;
             const uint64_t target_flops_gpu = 100ULL * GFLOP;
-            uint64_t target_flops = ggml_backend_is_cpu(backend) ? target_flops_cpu : target_flops_gpu;
+            uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
             n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
         } else {
             // based on memory size
             const size_t GB = 1ULL << 30;
             const size_t target_size_cpu = 8 * GB;
             const size_t target_size_gpu = 32 * GB;
-            size_t target_size = ggml_backend_is_cpu(backend) ? target_size_cpu : target_size_gpu;
+            size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
             n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
         }
 
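Note: together with the removal of #include <ggml-cpu.h> in the first hunk, this change replaces ggml_backend_is_cpu() with a device-type query, so the test depends only on the public ggml-backend API. A minimal sketch of the same idiom (the helper name is ours, not from the diff):

    #include <ggml-backend.h>

    // true when the backend's device reports itself as a CPU device;
    // mirrors the is_cpu check introduced in the hunk above
    static bool backend_is_cpu_device(ggml_backend_t backend) {
        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
        return dev != NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU;
    }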
@@ -819,7 +818,6 @@ struct test_case {
             }
         }
 
-        // TODO: refactor so that this check is only needed once
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (!ggml_backend_supports_op(backend, t)) {
                 printf("not supported [%s] ", ggml_backend_name(backend));
@@ -1155,6 +1153,26 @@ struct test_argmax : public test_case {
         return out;
     }
 
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_F32) {
+                // initialize with unique values to avoid ties
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<float> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+
     double max_nmse_err() override {
         return 0.0;
     }
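Note: the new initialize_tensors override fills each row with a shuffled permutation of 0..ne[0]-1, so every row has exactly one maximum. With random initialization, tied maxima would allow two correct argmax implementations to report different indices. A tiny standalone illustration of the ambiguity (example values are ours):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> v = {1.0f, 7.0f, 7.0f, 3.0f}; // tied maximum at indices 1 and 2
        // a forward scan picks index 1, a backward scan picks index 2 - both are valid,
        // which is why the test initializes rows with unique values instead
        size_t fwd = std::max_element(v.begin(), v.end()) - v.begin();
        size_t bwd = v.size() - 1 - (std::max_element(v.rbegin(), v.rend()) - v.rbegin());
        std::printf("forward: %zu, backward: %zu\n", fwd, bwd); // forward: 1, backward: 2
        return 0;
    }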
@@ -2183,7 +2201,15 @@ struct test_rope : public test_case {
             ggml_set_name(a, "a");
         }
 
-        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+        const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+        ggml_tensor * pos;
+        if (is_mrope || is_vision) {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
+        } else {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+        }
         ggml_set_name(pos, "pos");
 
         ggml_tensor * freq = nullptr;
@@ -2192,7 +2218,20 @@ struct test_rope : public test_case {
             ggml_set_name(freq, "freq");
         }
 
-        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+        ggml_tensor * out;
+        if (is_mrope) {
+            if (is_vision) {
+                GGML_ASSERT(n_dims/4 > 0);
+                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
+                out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            } else {
+                GGML_ASSERT(n_dims/3 > 0);
+                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
+                out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            }
+        } else {
+            out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+        }
         ggml_set_name(out, "out");
 
         return out;
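Note: ggml_rope_multi distributes the rotated dimensions over up to four position components via the sections array, and the position tensor accordingly carries four ids per token (the ne_a[2] * 4 allocation in the previous hunk). Spelling out the splits used by these tests for a concrete head size (values taken from the hunk; the component naming follows the usual M-RoPE convention and is our gloss):

    #include <cstdio>

    int main() {
        const int n_dims = 128;
        // M-RoPE: three equal sections (e.g. temporal/height/width), fourth unused
        int mrope_sections[4]  = {n_dims/3, n_dims/3, n_dims/3, 0};
        // Vision RoPE: only the (x, y) image coordinates, rotating n_dims/2 dims in total
        int vision_sections[4] = {n_dims/4, n_dims/4, 0, 0};
        std::printf("m-rope: %d+%d+%d of %d dims\n",
                    mrope_sections[0], mrope_sections[1], mrope_sections[2], n_dims);
        std::printf("vision: %d+%d of %d rotated dims\n",
                    vision_sections[0], vision_sections[1], n_dims/2);
        return 0;
    }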
@@ -2202,11 +2241,12 @@ struct test_rope : public test_case {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
                 // pos
-                std::vector<int> data(ne_a[2]);
-                for (int i = 0; i < ne_a[2]; i++) {
+                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
+                std::vector<int> data(num_pos_ids);
+                for (int i = 0; i < num_pos_ids; i++) {
                     data[i] = rand() % n_ctx;
                 }
-                ggml_backend_tensor_set(t, data.data(), 0, ne_a[2] * sizeof(int));
+                ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));
             } else {
                 if (t->ne[0] == n_dims/2) {
                     // frequency factors in the range [0.9f, 1.1f]
@@ -2679,6 +2719,33 @@ struct test_pad : public test_case {
     }
 };
 
+// GGML_OP_PAD_REFLECT_1D
+struct test_pad_reflect_1d : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const int pad_0;
+    const int pad_1;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+    }
+
+    test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {512, 34, 2, 1},
+            int pad_0 = 10, int pad_1 = 9)
+        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 2, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_ARANGE
 struct test_arange : public test_case {
     const ggml_type type;
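Note: test_pad_reflect_1d covers the new GGML_OP_PAD_REFLECT_1D operator via ggml_pad_reflect_1d(ctx, a, pad_0, pad_1). A standalone sketch of what reflection padding computes per row, assuming the usual convention of mirroring around the edge sample without repeating it (our reading, not code from the diff; valid for pads smaller than the row length):

    #include <cstdio>
    #include <vector>

    // reflect-pads one row, e.g. [0 1 2 3] with p0=2, p1=1 -> [2 1 0 1 2 3 2]
    static std::vector<float> pad_reflect_1d(const std::vector<float> & x, int p0, int p1) {
        const int n = (int) x.size();
        std::vector<float> out(p0 + n + p1);
        for (int i = 0; i < (int) out.size(); i++) {
            int src = i - p0;
            if (src < 0)  src = -src;             // mirror around index 0
            if (src >= n) src = 2*(n - 1) - src;  // mirror around index n-1
            out[i] = x[(size_t) src];
        }
        return out;
    }

    int main() {
        for (float v : pad_reflect_1d({0, 1, 2, 3}, 2, 1)) std::printf("%g ", v);
        std::printf("\n"); // prints: 2 1 0 1 2 3 2
        return 0;
    }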
@@ -3316,7 +3383,9 @@ static const ggml_type all_types[] = {
 
 static const ggml_type base_types[] = {
     GGML_TYPE_F32, GGML_TYPE_F16,
+    GGML_TYPE_Q8_0, // for I8MM tests
     GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1, // for I8MM tests
     GGML_TYPE_Q4_K,
     GGML_TYPE_IQ2_XXS
 };
@@ -3440,9 +3509,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-    test_cases.emplace_back(new test_argmax());
     test_cases.emplace_back(new test_count_equal());
 
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1}));
+
     for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
@@ -3468,10 +3543,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
     }
 
+    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
+        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
+    }
+
     for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         for (ggml_type type_dst : all_types) {
-            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
-            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
         }
     }
     for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
@@ -3547,6 +3626,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    for (int i = 1; i < 9; ++i) {
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
+    }
+
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
@@ -3743,6 +3835,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
             }
 
+            if (all) {
+                test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 2B)
+                test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl 7B)
+                test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v)); // rope_multi,m-rope (qwen2vl ViT)
+            }
+
             test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
         }
     }
@@ -3773,9 +3871,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_upscale());
     test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
     test_cases.emplace_back(new test_upscale_ext());
-    test_cases.emplace_back(new test_group_norm());
+    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
+    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());
+    test_cases.emplace_back(new test_pad_reflect_1d());
     test_cases.emplace_back(new test_arange());
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
@@ -3822,6 +3922,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
 
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
+
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
+
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
 
     for (int bs : {1, 512}) {
         for (ggml_type type_a : all_types) {
@@ -3837,7 +3951,11 @@
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     if (mode == MODE_TEST) {
         auto test_cases = make_test_cases_eval();
-        ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+        ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
+        if (backend_cpu == NULL) {
+            printf(" Failed to initialize CPU backend\n");
+            return false;
+        }
 
         size_t n_ok = 0;
         for (auto & test : test_cases) {
@@ -3917,7 +4035,9 @@ int main(int argc, char ** argv) {
         }
     }
 
-    // enumerate backends
+    // load and enumerate backends
+    ggml_backend_load_all();
+
     printf("Testing %zu devices\n\n", ggml_backend_dev_count());
 
     size_t n_ok = 0;
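Note: ggml_backend_load_all() belongs to the dynamic backend-loading rework in this release (ggml/src/ggml-backend-reg.cpp grows by roughly 380 lines in the file list above); backends are now discovered and loaded at runtime before device enumeration. A minimal enumeration sketch, assuming the registry accessors ggml_backend_dev_get/ggml_backend_dev_name from ggml-backend.h, which do not themselves appear in this diff:

    #include <cstdio>
    #include <ggml-backend.h>

    int main() {
        ggml_backend_load_all(); // discover dynamically loadable backends first
        for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            std::printf("device %zu: %s (type %d)\n",
                        i, ggml_backend_dev_name(dev), (int) ggml_backend_dev_type(dev));
        }
        return 0;
    }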
@@ -3933,16 +4053,15 @@ int main(int argc, char ** argv) {
             continue;
         }
 
-        ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
-        GGML_ASSERT(backend != NULL);
-
-        if (backend_filter == NULL && ggml_backend_is_cpu(backend) && mode != MODE_GRAD) {
+        if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {
             printf(" Skipping CPU backend\n");
-            ggml_backend_free(backend);
             n_ok++;
             continue;
         }
 
+        ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
+        GGML_ASSERT(backend != NULL);
+
         ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
         auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
         if (ggml_backend_set_n_threads_fn) {
package/src/llama.cpp/tests/test-chat-template.cpp

@@ -21,7 +21,7 @@ int main(void) {
     std::vector<std::string> templates = {
         // teknium/OpenHermes-2.5-Mistral-7B
         "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
-        // mistralai/Mistral-7B-Instruct-v0.2
+        // mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)
         "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
         // TheBloke/FusionNet_34Bx2_MoE-AWQ
         "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
@@ -67,16 +67,26 @@ int main(void) {
         "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
         // ibm-granite/granite-3.0-8b-instruct
         "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}",
+        // mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)
+        "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n",
+        // Mistral-Large-Instruct-2407 (mistralai 'v3' template)
+        "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
+        // Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template)
+        "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
+        // mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)
+        "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}",
+        // ai-sage/GigaChat-20B-A3B-instruct
+        "{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
         "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
-        // mistralai/Mistral-7B-Instruct-v0.2
+        // mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)
         "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
         // TheBloke/FusionNet_34Bx2_MoE-AWQ
-        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
+        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
         // bofenghuang/vigogne-2-70b-chat
-        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+        "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s>[INST] Who are you [/INST]I am an assistant</s>[INST] Another question [/INST]",
         // mlabonne/AlphaMonarch-7B
         "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
         // google/gemma-7b-it
@@ -113,10 +123,31 @@ int main(void) {
         u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
         // ibm-granite/granite-3.0-8b-instruct
         "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
+        // mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)
+        " [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
+        // Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)
+        "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
+        // Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)
+        "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]Another question[/INST]",
+        // mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)
+        "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant </s>[INST] Another question[/INST]",
+        // ai-sage/GigaChat-20B-A3B-instruct
+        "<s>You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
 
+    // list all supported templates
+    std::vector<const char *> supported_tmpl;
+    res = llama_chat_builtin_templates(nullptr, 0);
+    assert(res > 0);
+    supported_tmpl.resize(res);
+    res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size());
+    printf("Built-in chat templates:\n");
+    for (auto tmpl : supported_tmpl) {
+        printf(" %s\n", tmpl);
+    }
+
     // test invalid chat template
     res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
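Note: the new llama_chat_builtin_templates API follows a count-then-fill convention, as the hunk above shows: a call with a null buffer returns the number of built-in templates, and a second call fills the caller's array. The same idiom as a standalone sketch:

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    int main() {
        // first call with no output buffer returns the template count
        int32_t n = llama_chat_builtin_templates(nullptr, 0);
        if (n <= 0) return 1;
        std::vector<const char *> names(n);
        llama_chat_builtin_templates(names.data(), names.size());
        for (const char * name : names) std::printf("%s\n", name);
        return 0;
    }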
@@ -154,9 +185,16 @@ int main(void) {
         return output;
     };
     assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
+    assert(fmt_sys("mistral-v1") == " [INST] You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v3") == "[INST] You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v3-tekken") == "[INST]You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v7") == "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT]");
     assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
+    assert(fmt_sys("llama2-sys") == "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\n");
+    assert(fmt_sys("mistral") == "[INST] You are a helpful assistant\n"); // for old pre-v1 templates
     assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message
     assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
+    assert(fmt_sys("gigachat") == "<s>You are a helpful assistant<|message_sep|>");
 
 
     // test llama_chat_format_single for user message
@@ -173,9 +211,17 @@ int main(void) {
         return output;
     };
     assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("mistral-v1") == " [INST] How are you [/INST]");
+    assert(fmt_single("mistral-v3") == "[INST] How are you[/INST]");
+    assert(fmt_single("mistral-v3-tekken") == "[INST]How are you[/INST]");
+    assert(fmt_single("mistral-v7") == "[INST] How are you[/INST]");
     assert(fmt_single("llama2") == "[INST] How are you [/INST]");
+    assert(fmt_single("mistral") == "[INST] How are you [/INST]"); // for old pre-v1 templates
     assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+    assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
+
+    printf("Test chat templates: OK\n");
 
     return 0;
 }