@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/examples/gritlm/gritlm.cpp
@@ -1,3 +1,4 @@
+ #include "arg.h"
  #include "common.h"
  #include "llama.h"

@@ -9,25 +10,25 @@
  static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
  std::vector<std::vector<float>> result;

- const llama_model * mdl = llama_get_model(ctx);
+ const llama_model * model = llama_get_model(ctx);

  llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);

  for (uint64_t i = 0; i < sentences.size(); i++) {
- llama_batch_clear(batch);
+ common_batch_clear(batch);

  const std::string input_string = instruction + sentences[i];

- std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
+ std::vector<llama_token> inputs = common_tokenize(model, input_string, true, false);

  const int32_t n_toks = inputs.size();

  // GritLM seems to have EOS = ""
  // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
- // inputs.push_back(llama_token_eos(mdl));
+ // inputs.push_back(llama_token_eos(model));

  // we want to ignore instruction tokens for mean pooling
- const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();
+ const int32_t n_inst = common_tokenize(model, instruction, true, false).size();

  #ifdef GRIT_DEBUG
  // debug tokens - should be matching as referenced in the GritLM sample
@@ -39,7 +40,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

  // add input to batch (this increments n_tokens)
  for (int32_t j = 0; j < n_toks; j++) {
- llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
+ common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
  }

  // clear previous kv_cache values (irrelevant for embeddings)
@@ -51,7 +52,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
  llama_decode(ctx, batch);

  // get embedding dimensions
- uint64_t n_embd = llama_n_embd(mdl);
+ uint64_t n_embd = llama_n_embd(model);

  // allocate embedding output
  std::vector<float> emb_unorm(n_embd, 0.0f);
@@ -74,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
  }

  std::vector<float> emb_norm(emb_unorm.size());
- llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
+ common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
  result.push_back(emb_norm);

  #ifdef GRIT_DEBUG
@@ -92,11 +93,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
  return result;
  }

- static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
+ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
  std::string result;

- const llama_model * mdl = llama_get_model(ctx);
- llama_token eos_token = llama_token_eos(mdl);
+ const llama_model * model = llama_get_model(ctx);
+ llama_token eos_token = llama_token_eos(model);

  llama_kv_cache_clear(ctx);
  llama_set_embeddings(ctx, false);
@@ -104,33 +105,29 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo

  llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

- std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
+ std::vector<llama_token> inputs = common_tokenize(model, prompt, false, true);
  int32_t i_current_token = 0;

  while (true) {
- llama_batch_clear(bat);
- auto n_inputs = (int32_t)inputs.size();
- for (int32_t i = 0; i < n_inputs; i++) {
- llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+ common_batch_clear(bat);
+ {
+ const int32_t n_inputs = inputs.size();
+
+ for (int32_t i = 0; i < n_inputs; i++) {
+ common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+ }
  }
  inputs.clear();

  llama_decode(ctx, bat);
- auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

- auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
- auto n_candidates = (int32_t)candidates.size();
- for (int32_t token = 0; token < n_candidates; token++) {
- candidates[token] = llama_token_data{ token, logits[token], 0.0f };
- }
- auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
+ llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);

- llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
  if (token == eos_token) {
  break;
  }

- std::string piece = llama_token_to_piece(ctx, token);
+ std::string piece = common_token_to_piece(ctx, token);
  if (stream) {
  std::printf("%s", piece.c_str());
  std::fflush(stdout);
@@ -155,22 +152,31 @@ static std::string gritlm_instruction(const std::string & instruction) {
  }

  int main(int argc, char * argv[]) {
- gpt_params params;
+ common_params params;

- if (!gpt_params_parse(argc, argv, params)) {
- gpt_params_print_usage(argc, argv, params);
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
  return 1;
  }

- llama_model_params mparams = llama_model_params_from_gpt_params(params);
- llama_context_params cparams = llama_context_params_from_gpt_params(params);
+ common_init();
+
+ llama_model_params mparams = common_model_params_to_llama(params);
+ llama_context_params cparams = common_context_params_to_llama(params);

  llama_backend_init();

- llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);

  // create generation context
- llama_context * ctx = llama_new_context_with_model(mdl, cparams);
+ llama_context * ctx = llama_new_context_with_model(model, cparams);
+
+ auto sparams = llama_sampler_chain_default_params();
+
+ sparams.no_perf = false;
+
+ llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+ llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

  // ### Embedding/Representation ###
  // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -191,12 +197,12 @@ int main(int argc, char * argv[]) {
  const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
  const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));

- const int n_embd = llama_n_embd(mdl);
+ const int n_embd = llama_n_embd(model);

- const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
- const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
- const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
- const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);
+ const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
+ const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
+ const float cosine_sim_q1_d0 = common_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
+ const float cosine_sim_q1_d1 = common_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);

  std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
  std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
@@ -208,11 +214,12 @@ int main(int argc, char * argv[]) {
  // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
  {
  const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
- std::string response = generate(ctx, prompt, true);
+ std::string response = generate(ctx, smpl, prompt, true);
  }

+ llama_sampler_free(smpl);
  llama_free(ctx);
- llama_free_model(mdl);
+ llama_free_model(model);
  llama_backend_free();

  return 0;
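The gritlm diff above replaces the manual `llama_token_data_array` / `llama_sample_token_greedy` loop with the new sampler-chain API and the `common_*` helpers. Below is a minimal sketch of that decode pattern, based only on the calls visible in the diff; `greedy_generate` and the usage comments are illustrative, not part of the shipped example.

```cpp
#include "common.h"
#include "llama.h"

#include <string>
#include <vector>

// Sketch: greedy decoding with a llama_sampler chain, mirroring the updated
// generate() in gritlm.cpp (illustrative helper, not the shipped code).
static std::string greedy_generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt) {
    const llama_model * model = llama_get_model(ctx);

    llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

    std::vector<llama_token> inputs = common_tokenize(model, prompt, /*add_special*/ false, /*parse_special*/ true);

    std::string result;
    int32_t pos = 0;

    while (true) {
        // feed the pending tokens; request logits only for the last one
        common_batch_clear(bat);
        const int32_t n_inputs = (int32_t) inputs.size();
        for (int32_t i = 0; i < n_inputs; i++) {
            common_batch_add(bat, inputs[i], pos++, { 0 }, i == n_inputs - 1);
        }
        inputs.clear();

        llama_decode(ctx, bat);

        // the sampler chain replaces the old llama_token_data_array + llama_sample_token_greedy code
        const llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
        if (token == llama_token_eos(model)) {
            break;
        }

        result += common_token_to_piece(ctx, token);
        inputs.push_back(token); // continue from the sampled token on the next iteration
    }

    llama_batch_free(bat);
    return result;
}

// usage, after the model/context are created as in the example:
//   llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
//   llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
//   std::string out = greedy_generate(ctx, smpl, "Hello");
//   llama_sampler_free(smpl);
```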
package/src/llama.cpp/examples/imatrix/imatrix.cpp
@@ -1,4 +1,6 @@
+ #include "arg.h"
  #include "common.h"
+ #include "log.h"
  #include "llama.h"

  #include <cmath>
@@ -17,15 +19,13 @@
  #pragma warning(disable: 4244 4267) // possible loss of data
  #endif

- static void print_usage(int argc, char ** argv, const gpt_params & params) {
- gpt_params_print_usage(argc, argv, params);
-
- LOG_TEE("\nexample usage:\n");
- LOG_TEE("\n %s \\\n"
- " -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] [--verbosity 1] \\\n"
+ static void print_usage(int, char ** argv) {
+ LOG("\nexample usage:\n");
+ LOG("\n %s \\\n"
+ " -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] \\\n"
  " [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n"
  " [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]\n" , argv[0]);
- LOG_TEE("\n");
+ LOG("\n");
  }

  struct Stats {
@@ -37,13 +37,13 @@ struct Stats {
  class IMatrixCollector {
  public:
  IMatrixCollector() = default;
- void set_params(gpt_params params) { m_params = std::move(params); }
+ void set_params(common_params params) { m_params = std::move(params); }
  bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data);
  void save_imatrix(int ncall = -1) const;
  bool load_imatrix(const char * file_name);
  private:
  std::unordered_map<std::string, Stats> m_stats;
- gpt_params m_params;
+ common_params m_params;
  std::mutex m_mutex;
  int m_last_call = 0;
  std::vector<float> m_src1_data;
@@ -126,12 +126,10 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
  e.counts.resize(src1->ne[0]*n_as, 0);
  }
  else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
- fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
+ LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
  exit(1); //GGML_ABORT("fatal error");
  }
- if (m_params.verbosity > 1) {
- printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
- }
+ LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
  // loop over all possible experts, regardless if they are used or not in the batch
  for (int ex = 0; ex < n_as; ++ex) {
  size_t e_start = ex*src1->ne[0];
@@ -152,7 +150,8 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
  e.values[e_start + j] += x[j]*x[j];
  e.counts[e_start + j]++;
  if (!std::isfinite(e.values[e_start + j])) {
- fprintf(stderr, "%f detected in %s\n", e.values[e_start + j], wname.c_str());
+ LOG("\n");
+ LOG_ERR("%f detected in %s\n", e.values[e_start + j], wname.c_str());
  exit(1);
  }
  }
@@ -175,20 +174,18 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
  e.counts.resize(src1->ne[0], 0);
  }
  else if (e.values.size() != (size_t)src1->ne[0]) {
- fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
+ LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
  exit(1); //GGML_ABORT("fatal error");
  }
  ++e.ncall;
- if (m_params.verbosity > 1) {
- printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
- }
+ LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
  for (int row = 0; row < (int)src1->ne[1]; ++row) {
  const float * x = data + row * src1->ne[0];
  for (int j = 0; j < (int)src1->ne[0]; ++j) {
  e.values[j] += x[j]*x[j];
  e.counts[j]++;
  if (!std::isfinite(e.values[j])) {
- fprintf(stderr, "%f detected in %s\n", e.values[j], wname.c_str());
+ LOG_ERR("%f detected in %s\n", e.values[j], wname.c_str());
  exit(1);
  }
  }
@@ -240,17 +237,17 @@ void IMatrixCollector::save_imatrix(int ncall) const {
  }

  if (n_zeros != 0 && is_first) {
- fprintf(stderr, "\n");
+ LOG_INF("\n");
  is_first = false;
  }

  if (n_zeros == n_all) {
- fprintf(stderr, "%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str());
+ LOG_WRN("%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str());
  continue;
  }

  if (n_zeros > 0) {
- fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
+ LOG_WRN("%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
  continue;
  }

@@ -259,7 +256,7 @@ void IMatrixCollector::save_imatrix(int ncall) const {
  }

  if (to_store.size() < m_stats.size()) {
- fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
+ LOG_WRN("%s: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
  }

  std::ofstream out(fname, std::ios::binary);
@@ -291,21 +288,20 @@ void IMatrixCollector::save_imatrix(int ncall) const {
  out.write(m_params.prompt_file.c_str(), len);
  }

- if (m_params.verbosity > 0) {
- fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname.c_str());
- }
+ LOGV(1, "\n");
+ LOG_DBGV(1, "%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname.c_str());
  }

  bool IMatrixCollector::load_imatrix(const char * fname) {
  std::ifstream in(fname, std::ios::binary);
  if (!in) {
- printf("%s: failed to open %s\n",__func__, fname);
+ LOG_ERR("%s: failed to open %s\n",__func__, fname);
  return false;
  }
  int n_entries;
  in.read((char*)&n_entries, sizeof(n_entries));
  if (in.fail() || n_entries < 1) {
- printf("%s: no data in file %s\n", __func__, fname);
+ LOG_ERR("%s: no data in file %s\n", __func__, fname);
  return false;
  }
  for (int i = 0; i < n_entries; ++i) {
@@ -313,7 +309,7 @@ bool IMatrixCollector::load_imatrix(const char * fname) {
  std::vector<char> name_as_vec(len+1);
  in.read((char *)name_as_vec.data(), len);
  if (in.fail()) {
- printf("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname);
+ LOG_ERR("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname);
  return false;
  }
  name_as_vec[len] = 0;
@@ -324,7 +320,7 @@ bool IMatrixCollector::load_imatrix(const char * fname) {
  int nval;
  in.read((char *)&nval, sizeof(nval));
  if (in.fail() || nval < 1) {
- printf("%s: failed reading number of values for entry %d\n",__func__,i);
+ LOG_ERR("%s: failed reading number of values for entry %d\n",__func__,i);
  m_stats = {};
  return false;
  }
@@ -337,7 +333,7 @@ bool IMatrixCollector::load_imatrix(const char * fname) {
  std::vector<float> tmp(nval);
  in.read((char*)tmp.data(), nval*sizeof(float));
  if (in.fail()) {
- printf("%s: failed reading data for entry %d\n",__func__,i);
+ LOG_ERR("%s: failed reading data for entry %d\n",__func__,i);
  m_stats = {};
  return false;
  }
@@ -432,32 +428,31 @@ static void process_logits(
  }
  }

- static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
- const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
- GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
+ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
+ const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
+ GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
  const int n_ctx = llama_n_ctx(ctx);

  auto tim1 = std::chrono::high_resolution_clock::now();
- fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+ LOG_INF("%s: tokenizing the input ..\n", __func__);

- std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
+ std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);

  auto tim2 = std::chrono::high_resolution_clock::now();
- fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
+ LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());

  if (params.i_chunk > 0) {
  if (size_t((params.i_chunk + 2)*n_ctx) >= tokens.size()) {
- fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, params.i_chunk);
+ LOG_ERR("%s: there will be not enough tokens left after removing %d chunks\n", __func__, params.i_chunk);
  return false;
  }
- fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, params.i_chunk, params.i_chunk*n_ctx);
+ LOG_INF("%s: removing initial %d chunks (%d tokens)\n", __func__, params.i_chunk, params.i_chunk*n_ctx);
  tokens.erase(tokens.begin(), tokens.begin() + params.i_chunk*n_ctx);
  }

  if (int(tokens.size()) < 2*n_ctx) {
- fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
- n_ctx);
- fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
+ LOG_ERR("%s: you need at least %d tokens for a context of %d tokens\n", __func__, 2*n_ctx, n_ctx);
+ LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n", __func__, tokens.size());
  return false;
  }

@@ -479,7 +474,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
  double nll = 0.0;
  double nll2 = 0.0;

- fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);
+ LOG_INF("%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);

  std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

@@ -501,6 +496,8 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
  // clear the KV cache
  llama_kv_cache_clear(ctx);

+ llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
  for (int j = 0; j < num_batches; ++j) {
  const int batch_start = start + j * n_batch;
  const int batch_size = std::min(end - batch_start, n_batch);
@@ -513,9 +510,14 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
  tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
  }

- // TODO: use batch.logits to save computations instead of relying on logits_all == true
- if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
- fprintf(stderr, "%s : failed to eval\n", __func__);
+ common_batch_clear(batch);
+ for (int i = 0; i < batch_size; i++) {
+ common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+ }
+
+ if (llama_decode(ctx, batch)) {
+ LOG_ERR("%s : failed to eval\n", __func__);
+ llama_batch_free(batch);
  return false;
  }

@@ -528,33 +530,35 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
  }
  }

+ llama_batch_free(batch);
+
  const auto t_end = std::chrono::high_resolution_clock::now();

  if (i == 0) {
  const float t_total = std::chrono::duration<float>(t_end - t_start).count();
- fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+ LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
  int total_seconds = (int)(t_total * n_chunk);
  if (total_seconds >= 60*60) {
- fprintf(stderr, "%d hours ", total_seconds / (60*60));
+ LOG("%d hours ", total_seconds / (60*60));
  total_seconds = total_seconds % (60*60);
  }
- fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+ LOG("%.2f minutes\n", total_seconds / 60.0);
  }

  if (params.compute_ppl) {
  const int first = n_ctx/2;
- const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+ const auto * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
  process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
  workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
  count += n_ctx - first - 1;

- printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+ LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
  fflush(stdout);

  logits.clear();
  }
  }
- printf("\n");
+ LOG("\n");

  if (params.compute_ppl) {
  nll2 /= count;
@@ -563,9 +567,9 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
  nll2 -= nll * nll;
  if (nll2 > 0) {
  nll2 = sqrt(nll2/(count-1));
- printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
+ LOG("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
  } else {
- printf("Unexpected negative standard deviation of log(prob)\n");
+ LOG("Unexpected negative standard deviation of log(prob)\n");
  }
  }

@@ -573,31 +577,32 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
  }

  int main(int argc, char ** argv) {
- gpt_params params;
+ common_params params;

  params.n_ctx = 512;
  params.logits_all = true;
- params.verbosity = 1;
+ params.escape = false;

- if (!gpt_params_parse(argc, argv, params)) {
- print_usage(argc, argv, params);
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) {
  return 1;
  }

+ common_init();
+
  params.n_batch = std::min(params.n_batch, params.n_ctx);

  g_collector.set_params(params);

  for (const auto & in_file : params.in_files) {
- printf("%s : loading imatrix from '%s'\n", __func__, in_file.c_str());
+ LOG_INF("%s : loading imatrix from '%s'\n", __func__, in_file.c_str());
  if (!g_collector.load_imatrix(in_file.c_str())) {
- fprintf(stderr, "%s : failed to load %s\n", __func__, in_file.c_str());
+ LOG_ERR("%s : failed to load %s\n", __func__, in_file.c_str());
  return 1;
  }
  }

  if (params.in_files.size() > 1) {
- printf("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str());
+ LOG_INF("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str());
  g_collector.save_imatrix();
  }

@@ -611,25 +616,25 @@ int main(int argc, char ** argv) {
  params.warmup = false;

  // init
- llama_model * model;
- llama_context * ctx;
+ common_init_result llama_init = common_init_from_params(params);

- std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ llama_model * model = llama_init.model;
+ llama_context * ctx = llama_init.context;
  if (model == nullptr || ctx == nullptr) {
- fprintf(stderr, "%s : failed to init\n", __func__);
+ LOG_ERR("%s : failed to init\n", __func__);
  return 1;
  }

  const int n_ctx_train = llama_n_ctx_train(model);
  if (params.n_ctx > n_ctx_train) {
- fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
+ LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
  __func__, n_ctx_train, params.n_ctx);
  }

  // print system information
  {
- fprintf(stderr, "\n");
- fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str());
+ LOG_INF("\n");
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
  }

  if (!compute_imatrix(ctx, params)) {
@@ -638,7 +643,8 @@ int main(int argc, char ** argv) {

  g_collector.save_imatrix();

- llama_print_timings(ctx);
+ LOG("\n");
+ llama_perf_context_print(ctx);

  llama_free(ctx);
  llama_free_model(model);
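The imatrix diff above drops `llama_batch_get_one()` in favour of a batch built explicitly with `common_batch_clear` / `common_batch_add`, with logits requested for every token. Below is a minimal sketch of that pattern, assuming a valid context and an already tokenized chunk; `evaluate_chunk` is a hypothetical helper name, not part of the shipped example.

```cpp
#include "common.h"
#include "log.h"
#include "llama.h"

#include <algorithm>
#include <vector>

// Sketch: evaluate a token chunk in slices of up to n_batch tokens, requesting
// logits for every token, mirroring the loop that replaced llama_batch_get_one()
// in imatrix.cpp (illustrative helper, not the shipped code).
static bool evaluate_chunk(llama_context * ctx, const std::vector<llama_token> & tokens, int n_batch) {
    llama_batch batch = llama_batch_init(n_batch, 0, 1);

    const int n_tokens    = (int) tokens.size();
    const int num_batches = (n_tokens + n_batch - 1) / n_batch;

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = j * n_batch;
        const int batch_size  = std::min(n_tokens - batch_start, n_batch);

        common_batch_clear(batch);
        for (int i = 0; i < batch_size; i++) {
            // token, position, sequence id {0}, logits requested
            common_batch_add(batch, tokens[batch_start + i], batch_start + i, { 0 }, true);
        }

        if (llama_decode(ctx, batch)) {
            LOG_ERR("%s : failed to eval\n", __func__);
            llama_batch_free(batch);
            return false;
        }
    }

    llama_batch_free(batch);
    return true;
}
```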