@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
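The hunks that follow appear to come from package/src/llama.cpp/examples/perplexity/perplexity.cpp (entry 84 in the list above) and show the main API migration in this llama.cpp update: the shared helpers gain a common_ prefix (gpt_params becomes common_params, ::llama_tokenize becomes common_tokenize, llama_batch_clear/llama_batch_add become common_batch_clear/common_batch_add), argument handling moves into the new common/arg.h, and direct fprintf/printf calls are replaced by the LOG, LOG_INF and LOG_ERR macros from the new common/log.h. The sketch below summarizes the new calling convention using only calls that appear in the hunks; the helper itself is illustrative and is not part of the package.

    #include "common.h"
    #include "log.h"
    #include "llama.h"

    #include <vector>

    // Illustrative helper (not from the package): tokenize the configured prompt
    // with the renamed common_tokenize() and report progress via the new macros.
    static size_t count_prompt_tokens(llama_context * ctx, const common_params & params) {
        LOG_INF("%s: tokenizing the input ..\n", __func__);
        std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
        if (tokens.empty()) {
            LOG_ERR("%s: the prompt produced no tokens\n", __func__);
        }
        return tokens.size();
    }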
@@ -1,18 +1,21 @@
+ #include "arg.h"
  #include "common.h"
+ #include "log.h"
  #include "llama.h"

+ #include <algorithm>
+ #include <array>
+ #include <atomic>
  #include <cmath>
  #include <cstdio>
  #include <cstring>
  #include <ctime>
+ #include <fstream>
+ #include <mutex>
+ #include <random>
  #include <sstream>
  #include <thread>
- #include <mutex>
- #include <atomic>
  #include <vector>
- #include <array>
- #include <fstream>
- #include <sstream>

  #if defined(_MSC_VER)
  #pragma warning(disable: 4244 4267) // possible loss of data
@@ -31,55 +34,6 @@ struct results_log_softmax {
  float prob;
  };

- static void write_logfile(
- const llama_context * ctx, const gpt_params & params, const llama_model * model,
- const struct results_perplexity & results
- ) {
- if (params.logdir.empty()) {
- return;
- }
-
- if (params.hellaswag) {
- fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__);
- return;
- }
-
- const std::string timestamp = string_get_sortable_timestamp();
-
- const bool success = fs_create_directory_with_parents(params.logdir);
- if (!success) {
- fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
- __func__, params.logdir.c_str());
- return;
- }
-
- const std::string logfile_path = params.logdir + timestamp + ".yml";
- FILE * logfile = fopen(logfile_path.c_str(), "w");
-
- if (logfile == NULL) {
- fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
- return;
- }
-
- fprintf(logfile, "binary: main\n");
- char model_desc[128];
- llama_model_desc(model, model_desc, sizeof(model_desc));
- yaml_dump_non_result_info(logfile, params, ctx, timestamp, results.tokens, model_desc);
-
- fprintf(logfile, "\n");
- fprintf(logfile, "######################\n");
- fprintf(logfile, "# Perplexity Results #\n");
- fprintf(logfile, "######################\n");
- fprintf(logfile, "\n");
-
- yaml_dump_vector_float(logfile, "logits", results.logits);
- fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
- yaml_dump_vector_float(logfile, "probs", results.probs);
-
- llama_dump_timing_info_yaml(logfile, ctx);
- fclose(logfile);
- }
-
  static std::vector<float> softmax(const std::vector<float>& logits) {
  std::vector<float> probs(logits.size());
  float max_logit = logits[0];
@@ -166,7 +120,7 @@ static void process_logits(
  break;
  }
  lock.unlock();
- const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+ const results_log_softmax results = log_softmax(n_vocab, logits + size_t(i)*n_vocab, tokens[i+1]);
  const double v = -results.log_softmax;
  local_nll += v;
  local_nll2 += v*v;
@@ -200,7 +154,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
  break;
  }
  lock.unlock();
- const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+ const double v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
  local_nll += v;
  local_nll2 += v*v;
  }
@@ -278,7 +232,9 @@ static std::pair<double, float> log_softmax(int n_vocab, const float * logits, c
  kld.sum_kld += sum;
  kld.sum_kld2 += sum*sum;
  ++kld.count;
- if (imax == imax_base) ++kld.n_same_top;
+ if (imax == imax_base) {
+ ++kld.n_same_top;
+ }

  const float p_base = expf(-nll_base);
  const float p = expf(-nll);
@@ -320,7 +276,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
  break;
  }
  lock.unlock();
- std::pair<double, float> v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+ std::pair<double, float> v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
  kld_values[i] = (float)v.first;
  p_diff_values[i] = v.second;
  }
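The size_t(i) casts introduced in the hunks above are overflow fixes: i and n_vocab are both plain int, so the row offset i*n_vocab used to be computed in 32-bit arithmetic before being added to the logits pointer. A small self-contained illustration with made-up numbers (not taken from the diff):

    #include <cstddef>

    int main() {
        const int i       = 20000;   // logit row index (illustrative)
        const int n_vocab = 128256;  // a vocabulary in the ~128k range (illustrative)
        // i * n_vocab == 2565120000, which exceeds INT_MAX (2147483647), so one
        // operand is promoted before the multiply, as the diff now does:
        const std::size_t offset = std::size_t(i) * n_vocab;
        (void) offset;  // would be used as: logits + offset
        return 0;
    }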
@@ -334,25 +290,25 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
  }
  }

- static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
+ static results_perplexity perplexity_v2(llama_context * ctx, const common_params & params) {
  // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
  // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
  // Output: `perplexity: 13.5106 [114/114]`
  // BOS tokens will be added for each chunk before eval

- const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
- GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
+ const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
+ GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));

- fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+ LOG_INF("%s: tokenizing the input ..\n", __func__);

- std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
+ std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);

  const int n_ctx = llama_n_ctx(ctx);

  if (int(tokens.size()) < 2*n_ctx) {
- fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
+ LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
  n_ctx);
- fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
+ LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
  return {std::move(tokens), 0., {}, {}};
  }

@@ -363,16 +319,16 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
  prob_history.resize(tokens.size());

  if (params.ppl_stride <= 0) {
- fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
+ LOG_ERR("%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
  return {tokens, -1, logit_history, prob_history};
  }

  const int calc_chunk = n_ctx;

- fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
+ LOG_INF("%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);

  if (int(tokens.size()) <= calc_chunk) {
- fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
+ LOG_ERR("%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
  tokens.size(), n_ctx, params.ppl_stride);
  return {tokens, -1, logit_history, prob_history};
  }
@@ -380,20 +336,21 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
  const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;

  const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
- const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  const int n_batch = params.n_batch;

+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
  int count = 0;
  double nll = 0.0;

- fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+ LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

  for (int i = 0; i < n_chunk; ++i) {
  const int start = i * params.ppl_stride;
  const int end = start + calc_chunk;

  const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
- //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
+ //LOG_DBG("%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);

  std::vector<float> logits;

@@ -402,14 +359,21 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
  // clear the KV cache
  llama_kv_cache_clear(ctx);

+ llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
  for (int j = 0; j < num_batches; ++j) {
  const int batch_start = start + j * n_batch;
  const int batch_size = std::min(end - batch_start, n_batch);

- //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
- // TODO: use llama_batch.logits instead of relying on logits_all == true
- if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
- //fprintf(stderr, "%s : failed to eval\n", __func__);
+ common_batch_clear(batch);
+ for (int i = 0; i < batch_size; i++) {
+ common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+ }
+
+ //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
+ if (llama_decode(ctx, batch)) {
+ //LOG_ERR("%s : failed to eval\n", __func__);
+ llama_batch_free(batch);
  return {tokens, -1, logit_history, prob_history};
  }

@@ -421,34 +385,35 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
  tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
  }

- const auto batch_logits = llama_get_logits(ctx);
- logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+ const auto * batch_logits = llama_get_logits(ctx);
+ logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);

  if (j == 0) {
  tokens[batch_start] = token_org;
  }
  }

+ llama_batch_free(batch);
+
  const auto t_end = std::chrono::high_resolution_clock::now();

  if (i == 0) {
  const float t_total = std::chrono::duration<float>(t_end - t_start).count();
- fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+ LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
  int total_seconds = (int)(t_total * n_chunk);
  if (total_seconds >= 60*60) {
- fprintf(stderr, "%d hours ", total_seconds / (60*60));
+ LOG("%d hours ", total_seconds / (60*60));
  total_seconds = total_seconds % (60*60);
  }
- fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+ LOG("%.2f minutes\n", total_seconds / 60.0);
  }

- //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
+ //LOG_DBG("%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
  for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
-
  // Calculate probability of next token, given the previous ones.
  const std::vector<float> tok_logits(
- logits.begin() + (j + 0) * n_vocab,
- logits.begin() + (j + 1) * n_vocab);
+ logits.begin() + size_t(j + 0) * n_vocab,
+ logits.begin() + size_t(j + 1) * n_vocab);

  const float prob = softmax(tok_logits)[tokens[start + j + 1]];
  logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
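In the @@ -402,14 +359,21 @@ hunk above, the old four-argument llama_batch_get_one(...) call is replaced by an explicitly constructed batch. The sketch below shows that pattern in isolation, assuming the same includes as the sketch near the top of the diff; the helper name and signature are illustrative, not from the package.

    // Decode one window of `tokens` by building the batch explicitly, the way
    // the perplexity_v2 hunk above now does it.
    static bool decode_window(llama_context * ctx, const std::vector<llama_token> & tokens,
                              int batch_start, int batch_size, int n_past, int n_batch) {
        llama_batch batch = llama_batch_init(n_batch, 0, 1);  // room for n_batch tokens, one sequence
        common_batch_clear(batch);
        for (int i = 0; i < batch_size; i++) {
            // arguments: token id, absolute position, sequence ids, whether to keep this token's logits
            common_batch_add(batch, tokens[batch_start + i], n_past + i, { 0 }, true);
        }
        const bool ok = llama_decode(ctx, batch) == 0;
        if (!ok) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
        }
        llama_batch_free(batch);  // the batch must now be freed on every path
        return ok;
    }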
@@ -459,18 +424,17 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
  }
  // perplexity is e^(average negative log-likelihood)
  if (params.ppl_output_type == 0) {
- printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+ LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
  } else {
- printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
+ LOG("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
  }
- fflush(stdout);
  }
- printf("\n");
+ LOG("\n");

  return {tokens, std::exp(nll / count), logit_history, prob_history};
  }

- static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
+ static results_perplexity perplexity(llama_context * ctx, const common_params & params, const int32_t n_ctx) {
  if (params.ppl_stride > 0) {
  return perplexity_v2(ctx, params);
  }
@@ -480,33 +444,33 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
  // Output: `perplexity: 13.5106 [114/114]`
  // BOS tokens will be added for each chunk before eval

- const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
- GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
+ const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
+ GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));

  std::ofstream logits_stream;
  if (!params.logits_file.empty()) {
  logits_stream.open(params.logits_file.c_str(), std::ios::binary);
  if (!logits_stream.is_open()) {
- fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
+ LOG_ERR("%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
  return {};
  }
- fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
+ LOG_INF("%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
  logits_stream.write("_logits_", 8);
  logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
  }

  auto tim1 = std::chrono::high_resolution_clock::now();
- fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+ LOG_INF("%s: tokenizing the input ..\n", __func__);

- std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
+ std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);

  auto tim2 = std::chrono::high_resolution_clock::now();
- fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
+ LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());

  if (int(tokens.size()) < 2*n_ctx) {
- fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
+ LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
  n_ctx);
- fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
+ LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
  return {std::move(tokens), 0., {}, {}};
  }

@@ -519,9 +483,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
  const int n_chunk_max = tokens.size() / n_ctx;

  const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
- const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  const int n_batch = params.n_batch;

+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
  int count = 0;
  double nll = 0.0;
  double nll2 = 0.0;
@@ -536,10 +501,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

  std::vector<float> logits;
  if (num_batches > 1) {
- logits.reserve((size_t)n_ctx * n_vocab);
+ logits.reserve(size_t(n_ctx) * n_vocab);
  }

- fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
+ LOG_INF("%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);

  std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

@@ -612,13 +577,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
  }

  if (llama_decode(ctx, batch)) {
- fprintf(stderr, "%s : failed to eval\n", __func__);
+ LOG_INF("%s : failed to eval\n", __func__);
  return {tokens, -1, logit_history, prob_history};
  }

  if (num_batches > 1 && n_outputs > 0) {
  const auto * batch_logits = llama_get_logits(ctx);
- logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
+ logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
  }
  }

@@ -627,13 +592,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
  llama_synchronize(ctx);
  const auto t_end = std::chrono::high_resolution_clock::now();
  const float t_total = std::chrono::duration<float>(t_end - t_start).count();
- fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+ LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
  int total_seconds = (int)(t_total*n_chunk/n_seq);
  if (total_seconds >= 60*60) {
- fprintf(stderr, "%d hours ", total_seconds / (60*60));
+ LOG("%d hours ", total_seconds / (60*60));
  total_seconds = total_seconds % (60*60);
  }
- fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+ LOG("%.2f minutes\n", total_seconds / 60.0);
  }

  for (int seq = 0; seq < n_seq_batch; seq++) {
@@ -655,19 +620,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

  // perplexity is e^(average negative log-likelihood)
  if (params.ppl_output_type == 0) {
- printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
+ LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
  } else {
  double av = nll/count;
  double av2 = nll2/count - av*av;
- if (av2 > 0) av2 = sqrt(av2/(count-1));
- printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+ if (av2 > 0) {
+ av2 = sqrt(av2/(count-1));
+ }
+ LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
  }
  }
- fflush(stdout);

  logits.clear();
  }
- printf("\n");
+ LOG("\n");

  nll2 /= count;
  nll /= count;
@@ -675,9 +641,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
  nll2 -= nll * nll;
  if (nll2 > 0) {
  nll2 = sqrt(nll2/(count-1));
- printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
+ LOG_INF("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
  } else {
- printf("Unexpected negative standard deviation of log(prob)\n");
+ LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
  }

  llama_batch_free(batch);
@@ -685,10 +651,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
  return {tokens, ppl, logit_history, prob_history};
  }

- static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
+ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
  int prev_outputs = 0;
- for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
- const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+ for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
+ const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);

  llama_batch batch_view = {
  n_tokens,
@@ -698,12 +664,11 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
  batch.n_seq_id + i,
  batch.seq_id + i,
  batch.logits + i,
- 0, 0, 0, // unused
  };

  const int ret = llama_decode(ctx, batch_view);
  if (ret != 0) {
- LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+ LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
  return false;
  }

@@ -712,7 +677,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
  n_outputs += batch_view.logits[i] != 0;
  }

- memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+ memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));

  prev_outputs += n_outputs;
  }
@@ -727,7 +692,9 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
  if (eval_results.size() != eval_pairs.size()) {
  eval_results.resize(eval_pairs.size());
  }
- if (eval_pairs.empty()) return;
+ if (eval_pairs.empty()) {
+ return;
+ }

  size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());

@@ -735,11 +702,13 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
  auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
  float local_logprobs[K_TOKEN_CHUNK];
  while (true) {
- size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
- if (first >= eval_results.size()) break;
- size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
+ const size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
+ if (first >= eval_results.size()) {
+ break;
+ }
+ const size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
  for (size_t i = first; i < last; ++i) {
- auto logits = batch_logits + eval_pairs[i].first * n_vocab;
+ const auto * logits = batch_logits + eval_pairs[i].first * n_vocab;
  float max_logit = logits[0];
  for (int j = 1; j < n_vocab; ++j) {
  max_logit = std::max(max_logit, logits[j]);
@@ -762,7 +731,7 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
  }
  }

- static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
+ static void hellaswag_score(llama_context * ctx, const common_params & params) {
  // Calculates hellaswag score (acc_norm) from prompt
  //
  // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
@@ -789,15 +758,15 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  }

  if (prompt_lines.size() % 6 != 0) {
- fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
+ LOG_ERR("%s : number of lines in prompt not a multiple of 6.\n", __func__);
  return;
  }

  size_t hs_task_count = prompt_lines.size()/6;
- fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
+ LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);

  const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
- fprintf(stderr, "================================= is_spm = %d\n", is_spm);
+ LOG_INF("================================= is_spm = %d\n", is_spm);
  // The tasks should be randomized so the score stabilizes quickly.
  bool randomize_tasks = true;
@@ -824,7 +793,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  std::vector<llama_token> seq_tokens[4];
  };

- fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );
+ LOG_INF("%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );

  // Select and read data from prompt lines
  std::vector<hs_data_t> hs_data(hs_task_count);
@@ -843,7 +812,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
  for (size_t j = 0; j < 4; j++) {
  hs_cur.ending[j] = prompt_lines[idx*6+2+j];
- hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true);
+ hs_cur.seq_tokens[j] = common_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true);
  }

  // determine the common prefix of the endings
@@ -870,16 +839,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  }
  }

- fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
+ LOG_INF("%s : calculating hellaswag score over selected tasks.\n", __func__);

- printf("\ntask\tacc_norm\n");
+ LOG("\ntask\tacc_norm\n");

  double acc = 0.0f;

- const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  const int n_ctx = llama_n_ctx(ctx);
  const int n_batch = params.n_batch;

+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
  const int max_tasks_per_batch = 32;
  const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

@@ -887,7 +857,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

  std::vector<float> tok_logits(n_vocab);
  // TODO: this could be made smaller; it's currently the worst-case size
- std::vector<float> batch_logits(n_vocab*n_ctx);
+ std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

  std::vector<std::pair<size_t, llama_token>> eval_pairs;
  std::vector<float> eval_results;
@@ -899,7 +869,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  size_t i1 = i0;
  size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

- llama_batch_clear(batch);
+ common_batch_clear(batch);

  // batch as much tasks as possible into the available context
  // each task has 4 unique sequence ids - one for each ending
@@ -915,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  }

  for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
- llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
+ common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
  }
  batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
  n_logits += 1;
@@ -925,7 +895,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  // TODO: don't evaluate the last token of each sequence
  for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
  const bool needs_logits = i < seq_tokens_size - 1;
- llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+ common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
  n_logits += needs_logits;
  }
  }
@@ -940,7 +910,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  }

  if (i0 == i1) {
- fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+ LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0);
  return;
  }

@@ -948,7 +918,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

  // decode all tasks [i0, i1)
  if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
- fprintf(stderr, "%s: llama_decode() failed\n", __func__);
+ LOG_ERR("%s: llama_decode() failed\n", __func__);
  return;
  }

@@ -974,7 +944,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  auto & hs_cur = hs_data[i];

  // get the logits of the last token of the common prefix
- std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float));
+ std::memcpy(tok_logits.data(), batch_logits.data() + hs_cur.i_logits*n_vocab, n_vocab*sizeof(float));

  const auto first_probs = softmax(tok_logits);

@@ -998,7 +968,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  }
  }

- //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
+ //LOG("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);

  // If the gold ending got the maximum logprobe add one accuracy point
  if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
@@ -1006,8 +976,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
  }

  // Print the accumulated accuracy mean x 100
- printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
- fflush(stdout);
+ LOG("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
  }

  i0 = i1 - 1;
@@ -1015,7 +984,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

  llama_batch_free(batch);

- printf("\n");
+ LOG("\n");
  }

  struct winogrande_entry {
@@ -1059,7 +1028,7 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
  }
  }
  if (ipos != 4) {
- printf("%s: failed to find comma separators in <%s>\n", __func__, line.c_str());
+ LOG_ERR("%s: failed to find comma separators in <%s>\n", __func__, line.c_str());
  continue;
  }
  auto sentence = line[comma_pos[0]+1] == '"' ? line.substr(comma_pos[0]+2, comma_pos[1] - comma_pos[0] - 3)
@@ -1073,13 +1042,13 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
  if (sentence[where] == '_') break;
  }
  if (where == int(sentence.size())) {
- printf("%s: no _ in <%s>\n", __func__, sentence.c_str());
+ LOG_ERR("%s: no _ in <%s>\n", __func__, sentence.c_str());
  continue;
  }
  std::istringstream stream(answer.c_str());
  int i_answer; stream >> i_answer;
  if (stream.fail() || i_answer < 1 || i_answer > 2) {
- printf("%s: failed to parse answer <%s>\n", __func__, answer.c_str());
+ LOG_ERR("%s: failed to parse answer <%s>\n", __func__, answer.c_str());
  continue;
  }
  result.emplace_back();
@@ -1102,20 +1071,20 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
  * 0,Sarah was a much better surgeon than Maria so _ always got the easier cases.,Sarah,Maria,2
  *
  */
- static void winogrande_score(llama_context * ctx, const gpt_params & params) {
+ static void winogrande_score(llama_context * ctx, const common_params & params) {

  constexpr int k_min_trailing_ctx = 3;

  auto data = load_winogrande_from_csv(params.prompt);
  if (data.empty()) {
- fprintf(stderr, "%s: no tasks\n", __func__);
+ LOG_ERR("%s: no tasks\n", __func__);
  return;
  }

- fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, data.size());
+ LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, data.size());

  if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) {
- fprintf(stderr, "%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks);
+ LOG_INF("%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks);
  std::mt19937 rng(1);
  std::vector<int> aux(data.size());
  for (int i = 0; i < int(data.size()); ++i) {
@@ -1133,11 +1102,11 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
  data = std::move(selected);
  }

- fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
+ LOG_INF("%s : tokenizing selected tasks\n", __func__);

  for (auto & task : data) {
- task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, true);
- task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, true);
+ task.seq_tokens[0] = common_tokenize(ctx, task.first + task.choices[0] + task.second, true);
+ task.seq_tokens[1] = common_tokenize(ctx, task.first + task.choices[1] + task.second, true);

  task.common_prefix = 0;
  for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
@@ -1152,16 +1121,17 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
  task.seq_tokens[0].size() - task.common_prefix +
  task.seq_tokens[1].size() - task.common_prefix;

- task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], true).size();
- task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], true).size();
+ task.n_base1 = common_tokenize(ctx, task.first + task.choices[0], true).size();
+ task.n_base2 = common_tokenize(ctx, task.first + task.choices[1], true).size();
  }

- fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
+ LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);

- const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  const int n_ctx = llama_n_ctx(ctx);
  const int n_batch = params.n_batch;

+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
  const int max_tasks_per_batch = 128;
  const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

@@ -1169,7 +1139,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {

  std::vector<float> tok_logits(n_vocab);
  // TODO: this could be made smaller; it's currently the worst-case size
- std::vector<float> batch_logits(n_vocab*n_ctx);
+ std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

  std::vector<std::pair<size_t, llama_token>> eval_pairs;
  std::vector<float> eval_results;
@@ -1184,7 +1154,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
  size_t i1 = i0;
  size_t i_logits = 0;

- llama_batch_clear(batch);
+ common_batch_clear(batch);

  while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
  int n_logits = 0;
@@ -1194,7 +1164,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
  }

  for (size_t i = 0; i < data[i1].common_prefix; ++i) {
- llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
+ common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
  }
  batch.logits[batch.n_tokens - 1] = true;
  n_logits += 1;
@@ -1202,7 +1172,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
  for (int s = 0; s < 2; ++s) {
  // TODO: end before the last token, no need to predict past the end of the sequences
  for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
- llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+ common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
  n_logits += 1;
  }
  }
@@ -1217,7 +1187,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
  }

  if (i0 == i1) {
- fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+ LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0);
  return;
  }

@@ -1225,7 +1195,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {

  // decode all tasks [i0, i1)
  if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
- fprintf(stderr, "%s: llama_decode() failed\n", __func__);
+ LOG_ERR("%s: llama_decode() failed\n", __func__);
  return;
  }

@@ -1285,20 +1255,20 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
  ++n_done;

  // print the accumulated accuracy mean x 100
- printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
- fflush(stdout);
+ LOG("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
  }

  i0 = i1 - 1;
  }

- printf("\n");
+ LOG("\n");

  if (n_done < 100) return;

  const float p = 1.f*n_correct/n_done;
  const float sigma = 100.f*sqrt(p*(1-p)/(n_done-1));
- printf("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma);
+
+ LOG_INF("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma);
  }

  static bool deserialize_string(std::istream & in, std::string & str) {
@@ -1347,7 +1317,7 @@ struct multiple_choice_task {
  static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) {
  if (task.question.empty() || task.mc1.answers.empty()) {
  if (log_error) {
- printf("%s: found bad task with empty question and/or answers\n", __func__);
+ LOG_ERR("%s: found bad task with empty question and/or answers\n", __func__);
  }
  return false;
  }
@@ -1355,11 +1325,11 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic
1355
1325
  for (auto& answer : task.mc1.answers) {
1356
1326
  if (answer.empty()) {
1357
1327
  if (log_error) {
1358
- printf("%s: found empty answer\n", __func__);
1328
+ LOG_ERR("%s: found empty answer\n", __func__);
1359
1329
  }
1360
1330
  return false;
1361
1331
  }
1362
- task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, true));
1332
+ task.seq_tokens.emplace_back(::common_tokenize(ctx, task.question + " " + answer, true));
1363
1333
  }
1364
1334
  auto min_len = task.seq_tokens.front().size();
1365
1335
  for (auto& seq : task.seq_tokens) {
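One more rename surfaces in the hunk above: the ::llama_tokenize() convenience wrapper becomes ::common_tokenize(). A sketch, assuming the overload that takes a context and an add_special flag as used in this diff (tokenize_pair_sketch is a hypothetical helper):

    #include <string>
    #include <vector>
    #include "common.h"
    #include "llama.h"

    static std::vector<llama_token> tokenize_pair_sketch(llama_context * ctx,
                                                         const std::string & question,
                                                         const std::string & answer) {
        // true -> add the model's special/BOS tokens, matching the call in the diff
        return common_tokenize(ctx, question + " " + answer, true);
    }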
@@ -1403,20 +1373,20 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic
  // git@hf.co:datasets/Stevross/mmlu
  // https://huggingface.co/datasets/truthful_qa
  //
- static void multiple_choice_score(llama_context * ctx, const gpt_params & params) {
+ static void multiple_choice_score(llama_context * ctx, const common_params & params) {
 
  std::istringstream strstream(params.prompt);
  uint32_t n_task;
  strstream.read((char *)&n_task, sizeof(n_task));
  if (strstream.fail() || n_task == 0) {
- printf("%s: no tasks\n", __func__);
+ LOG_ERR("%s: no tasks\n", __func__);
  return;
  }
- printf("%s: there are %u tasks in prompt\n", __func__, n_task);
+ LOG_INF("%s: there are %u tasks in prompt\n", __func__, n_task);
  std::vector<uint32_t> task_pos(n_task);
  strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
  if (strstream.fail()) {
- printf("%s: failed to read task positions from prompt\n", __func__);
+ LOG_ERR("%s: failed to read task positions from prompt\n", __func__);
  return;
  }
 
@@ -1424,21 +1394,21 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
  if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) {
  // Use all tasks
  tasks.resize(n_task);
- printf("%s: reading tasks", __func__);
+ LOG_INF("%s: reading tasks", __func__);
  int n_dot = std::max((int) n_task/100, 1);
  int i = 0;
  for (auto& task : tasks) {
  ++i;
  if (!task.deserialize(strstream)) {
- printf("%s: failed to read task %d of %u\n", __func__, i, n_task);
+ LOG_ERR("%s: failed to read task %d of %u\n", __func__, i, n_task);
  return;
  }
- if (i%n_dot == 0) printf(".");
+ if (i%n_dot == 0) LOG(".");
  }
- printf("done\n");
+ LOG("done\n");
  }
  else {
- printf("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
+ LOG_INF("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
  std::mt19937 rng(1);
  std::vector<int> aux(n_task);
  for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
@@ -1451,18 +1421,16 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
  aux.pop_back();
  strstream.seekg(task_pos[idx], std::ios::beg);
  if (!task.deserialize(strstream)) {
- printf("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]);
+ LOG_ERR("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]);
  return;
  }
  }
  n_task = params.multiple_choice_tasks;
  }
 
- printf("%s: preparing task data", __func__);
- fflush(stdout);
+ LOG_INF("%s: preparing task data", __func__);
  if (n_task > 500) {
- printf("...");
- fflush(stdout);
+ LOG("...");
  std::atomic<int> counter(0);
  std::atomic<int> n_bad(0);
  auto prepare = [&counter, &n_bad, &tasks, ctx] () {
@@ -1486,11 +1454,10 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
  for (auto& w : workers) w = std::thread(prepare);
  prepare();
  for (auto& w : workers) w.join();
- printf("done\n");
- fflush(stdout);
+ LOG("done\n");
  int nbad = n_bad;
  if (nbad > 0) {
- printf("%s: found %d malformed tasks\n", __func__, nbad);
+ LOG_ERR("%s: found %d malformed tasks\n", __func__, nbad);
  return;
  }
  } else {
@@ -1502,28 +1469,28 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
  return;
  }
  if (i_task%n_dot == 0) {
- printf(".");
- fflush(stdout);
+ LOG(".");
  }
  }
- printf("done\n");
+ LOG("done\n");
  }
 
- printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());
+ LOG_INF("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());
 
- printf("\ntask\tacc_norm\n");
+ LOG("\ntask\tacc_norm\n");
 
- const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  const int n_ctx = llama_n_ctx(ctx);
  const int n_batch = params.n_batch;
 
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
  const int max_tasks_per_batch = 32;
  const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
  llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
  std::vector<float> tok_logits(n_vocab);
- std::vector<float> batch_logits(n_vocab*n_ctx);
+ std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
 
  std::vector<std::pair<size_t, llama_token>> eval_pairs;
  std::vector<float> eval_results;
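Besides the gpt_params -> common_params and printf -> LOG changes, the hunk above widens the batch_logits size computation to size_t(n_ctx)*n_vocab. With plain int operands the product is evaluated in 32-bit arithmetic and overflows for large context/vocabulary sizes, so on typical targets a negative or truncated value reaches the vector constructor. A self-contained illustration with made-up sizes (not values from the diff):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int n_ctx   = 32768;
        const int n_vocab = 128256;

        const int64_t exact = int64_t(n_ctx) * int64_t(n_vocab);  // 4202692608
        const int32_t low32 = (int32_t)(uint32_t) exact;          // what a 32-bit signed product would carry: -92274688
        const size_t  sized = size_t(n_ctx) * size_t(n_vocab);    // promote before multiplying, as the diff now does

        std::printf("exact  : %lld\n", (long long) exact);
        std::printf("as int : %d\n",  low32);
        std::printf("size_t : %zu\n", sized);
        return 0;
    }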
@@ -1540,7 +1507,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
  size_t i1 = i0;
  size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
 
- llama_batch_clear(batch);
+ common_batch_clear(batch);
 
  // batch as much tasks as possible into the available context
  // each task has 4 unique sequence ids - one for each ending
@@ -1563,7 +1530,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
 
  for (size_t i = 0; i < cur_task.common_prefix; ++i) {
  //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
- llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
+ common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
  }
  batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
  n_logits += 1;
@@ -1573,7 +1540,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
  // TODO: don't evaluate the last token of each sequence
  for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
  const bool needs_logits = i < seq_tokens_size - 1;
- llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+ common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
  n_logits += needs_logits;
  }
  }
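The multiple-choice batching above leans on another property of common_batch_add(): a single token can be tagged with several sequence ids, so the shared question prefix is decoded once and its KV-cache entries are reused by every answer ending. A hedged sketch of that pattern (add_shared_prefix_sketch, prefix and endings are illustrative, not from the package):

    #include <vector>
    #include "common.h"
    #include "llama.h"

    static void add_shared_prefix_sketch(llama_batch & batch,
                                         const std::vector<llama_token> & prefix,
                                         const std::vector<std::vector<llama_token>> & endings) {
        std::vector<llama_seq_id> all_seqs;
        for (size_t s = 0; s < endings.size(); ++s) {
            all_seqs.push_back((llama_seq_id) s);
        }
        // shared prefix (assumed non-empty): one copy of each token, tagged with every sequence id
        for (size_t i = 0; i < prefix.size(); ++i) {
            common_batch_add(batch, prefix[i], (llama_pos) i, all_seqs, false);
        }
        batch.logits[batch.n_tokens - 1] = true; // logits for the last shared token
        // each ending then continues its own sequence after the prefix
        for (size_t s = 0; s < endings.size(); ++s) {
            for (size_t i = 0; i < endings[s].size(); ++i) {
                common_batch_add(batch, endings[s][i], (llama_pos)(prefix.size() + i),
                                 { (llama_seq_id) s }, true);
            }
        }
    }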
@@ -1590,7 +1557,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
  }
 
  if (i0 == i1) {
- fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+ LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0);
  return;
  }
 
@@ -1598,7 +1565,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
 
  // decode all tasks [i0, i1)
  if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
- fprintf(stderr, "%s: llama_decode() failed\n", __func__);
+ LOG_ERR("%s: llama_decode() failed\n", __func__);
  return;
  }
 
@@ -1622,16 +1589,16 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
  // compute the logprobs for each ending of the decoded tasks
  for (size_t i = i0; i < i1; ++i) {
  auto & cur_task = tasks[i];
- //printf("==== Evaluating <%s> with correct answer ", cur_task.question.c_str());
+ //LOG("==== Evaluating <%s> with correct answer ", cur_task.question.c_str());
  //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) {
  // if (cur_task.mc1.labels[j] == 1) {
- // printf("%d", j+1);
+ // LOG("%d", j+1);
  // }
  //}
- //printf("\n common_prefix: %zu\n", cur_task.common_prefix);
+ //LOG("\n common_prefix: %zu\n", cur_task.common_prefix);
 
  // get the logits of the last token of the common prefix
- std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float));
+ std::memcpy(tok_logits.data(), batch_logits.data() + cur_task.i_logits*n_vocab, n_vocab*sizeof(float));
 
  const auto first_probs = softmax(tok_logits);
 
@@ -1640,13 +1607,13 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
  size_t count = 1;
  float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
  for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
- //printf(" %zu %g\n", ir, eval_results[ir]);
+ //LOG(" %zu %g\n", ir, eval_results[ir]);
  ++count;
  log_prob += eval_results[ir++];
  }
  cur_task.log_probs[s] = log_prob / count;
- //printf(" Final: %g\n", log_prob / count);
- //printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
+ //LOG(" Final: %g\n", log_prob / count);
+ //LOG(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
  }
 
  // Find the ending with maximum logprob
@@ -1666,8 +1633,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
  ++n_done;
 
  // Print the accumulated accuracy mean x 100
- printf("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
- fflush(stdout);
+ LOG("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
  }
 
  i0 = i1 - 1;
@@ -1679,29 +1645,30 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
 
  float p = 1.f*n_correct/n_done;
  float sigma = sqrt(p*(1-p)/(n_done-1));
- printf("\n Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
+ LOG("\n");
+ LOG_INF("Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
  p = 1.f*n_done/n_tot_answers;
  sigma = sqrt(p*(1-p)/(n_done-1));
- printf("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
+ LOG_INF("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
 
- printf("\n");
+ LOG_INF("\n");
  }
 
- static void kl_divergence(llama_context * ctx, const gpt_params & params) {
+ static void kl_divergence(llama_context * ctx, const common_params & params) {
  if (params.logits_file.empty()) {
- fprintf(stderr, "%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
+ LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
  return;
  }
  std::ifstream in(params.logits_file.c_str(), std::ios::binary);
  if (!in) {
- fprintf(stderr, "%s: failed to open %s\n", __func__, params.logits_file.c_str());
+ LOG_ERR("%s: failed to open %s\n", __func__, params.logits_file.c_str());
  return;
  }
  {
  char check[9]; check[8] = 0;
  in.read(check, 8);
  if (in.fail() || strncmp("_logits_", check, 8) != 0) {
- fprintf(stderr, "%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str());
+ LOG_ERR("%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str());
  return;
  }
  }
@@ -1709,39 +1676,40 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
  uint32_t n_ctx;
  in.read((char *)&n_ctx, sizeof(n_ctx));
  if (n_ctx > llama_n_ctx(ctx)) {
- fprintf(stderr, "%s: %s has been computed with %u, while the current context is %d. Increase it with -c and retry\n",
+ LOG_ERR("%s: %s has been computed with %u, while the current context is %d. Increase it with -c and retry\n",
  __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
  }
 
- int n_vocab, n_chunk;
+ int n_vocab;
+ int n_chunk;
  in.read((char *)&n_vocab, sizeof(n_vocab));
  in.read((char *)&n_chunk, sizeof(n_chunk));
  if (in.fail()) {
- fprintf(stderr, "%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
+ LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
  return;
  }
  if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
- fprintf(stderr, "%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
+ LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
  }
 
- std::vector<llama_token> tokens(n_ctx * n_chunk);
+ std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
  if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
- fprintf(stderr, "%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
+ LOG_ERR("%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
  return;
  }
 
  const int n_batch = params.n_batch;
  const int num_batches = (n_ctx + n_batch - 1)/n_batch;
  const int nv = 2*((n_vocab + 1)/2) + 4;
- const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
- GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
+ const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
+ GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
 
  std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
  std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
  std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
  std::vector<float> logits;
  if (num_batches > 1) {
- logits.reserve(n_ctx * n_vocab);
+ logits.reserve(size_t(n_ctx) * n_vocab);
  }
 
  std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@@ -1775,13 +1743,15 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
  const auto t_start = std::chrono::high_resolution_clock::now();
 
  if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
- fprintf(stderr, "%s: failed reading log-probs for chunk %d\n", __func__, i);
+ LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
  return;
  }
 
  // clear the KV cache
  llama_kv_cache_clear(ctx);
 
+ llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
  for (int j = 0; j < num_batches; ++j) {
  const int batch_start = start + j * n_batch;
  const int batch_size = std::min(end - batch_start, n_batch);
@@ -1794,9 +1764,14 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
  tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
  }
 
- // TODO: use llama_batch.logits instead of relying on logits_all == true
- if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
- fprintf(stderr, "%s : failed to eval\n", __func__);
+ common_batch_clear(batch);
+ for (int i = 0; i < batch_size; i++) {
+ common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+ }
+
+ if (llama_decode(ctx, batch)) {
+ LOG_ERR("%s : failed to eval\n", __func__);
+ llama_batch_free(batch);
  return;
  }
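The rewritten loop above replaces llama_batch_get_one() plus the context-wide logits_all flag with an explicit batch: allocated once per chunk, cleared and refilled for every sub-batch, and with logits requested for every token so per-position log-probs can be read back after llama_decode(). A hedged sketch of that shape (decode_chunk_sketch is illustrative, not code from the package):

    #include <vector>
    #include "common.h"
    #include "log.h"
    #include "llama.h"

    static bool decode_chunk_sketch(llama_context * ctx, const std::vector<llama_token> & chunk) {
        llama_batch batch = llama_batch_init((int32_t) chunk.size(), 0, 1);
        common_batch_clear(batch);
        for (size_t i = 0; i < chunk.size(); ++i) {
            // request logits for every position, not just the last one
            common_batch_add(batch, chunk[i], (llama_pos) i, { 0 }, true);
        }
        const bool ok = llama_decode(ctx, batch) == 0;
        if (ok) {
            // llama_get_logits() now exposes one row of n_vocab floats per requested position
            const float * logits = llama_get_logits(ctx);
            (void) logits; // consumed by the caller in the real code
        } else {
            LOG_ERR("%s : failed to eval\n", __func__);
        }
        llama_batch_free(batch);
        return ok;
    }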
 
@@ -1805,105 +1780,105 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 
  if (num_batches > 1) {
  const auto * batch_logits = llama_get_logits(ctx);
- logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+ logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
  }
  }
 
+ llama_batch_free(batch);
+
  const auto t_end = std::chrono::high_resolution_clock::now();
 
  if (i == 0) {
  const float t_total = std::chrono::duration<float>(t_end - t_start).count();
- fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+ LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
  int total_seconds = (int)(t_total * n_chunk);
  if (total_seconds >= 60*60) {
- fprintf(stderr, "%d hours ", total_seconds / (60*60));
+ LOG("%d hours ", total_seconds / (60*60));
  total_seconds = total_seconds % (60*60);
  }
- fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
-
- printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
+ LOG("%.2f minutes\n", total_seconds / 60.0);
  }
+ LOG("\n");
+ LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
 
  const int first = n_ctx/2;
  const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
- process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+ process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
  workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
  p_diff_ptr += n_ctx - 1 - first;
  kld_ptr += n_ctx - 1 - first;
 
- printf("%4d", i+1);
+ LOG("%4d", i+1);
 
  auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
  const double ppl_val = exp(log_ppl.first);
  const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
- printf(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
+ LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
 
  auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
  const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
  const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
  const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
- printf(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
+ LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
 
  auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
- printf(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
+ LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
 
  auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
  const double p_diff_rms_val = sqrt(p_diff_mse.first);
  const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
- printf(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
+ LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
 
  double p_top_val = 1.*kld.n_same_top/kld.count;
  double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
- printf(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
-
- printf("\n");
+ LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
 
- fflush(stdout);
+ LOG("\n");
 
  logits.clear();
  }
- printf("\n");
+ LOG("\n");
 
  if (kld.count < 100) return; // we do not wish to do statistics on so few values
 
  std::sort(kld_values.begin(), kld_values.end());
  std::sort(p_diff_values.begin(), p_diff_values.end());
 
- printf("====== Perplexity statistics ======\n");
+ LOG("====== Perplexity statistics ======\n");
 
  auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
  const double ppl_val = exp(log_ppl.first);
  const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
- printf("Mean PPL(Q) : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc);
+ LOG("Mean PPL(Q) : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc);
 
  auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
  const double ppl_base_val = exp(log_ppl_base.first);
  const double ppl_base_unc = ppl_base_val * log_ppl_base.second; // ppl_base_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_base.second ** 2 )
- printf("Mean PPL(base) : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc);
+ LOG("Mean PPL(base) : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc);
 
  const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
- // printf("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov);
+ // LOG("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov);
  const double log_ppl_cor = log_ppl_cov / (log_ppl.second*log_ppl_base.second);
- printf("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor);
+ LOG("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor);
 
  const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
  const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
- printf("Mean ln(PPL(Q)/PPL(base)) : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc);
+ LOG("Mean ln(PPL(Q)/PPL(base)) : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc);
 
  const double ppl_ratio_val = exp(log_ppl_ratio_val);
  const double ppl_ratio_unc = ppl_ratio_val * log_ppl_ratio_unc; // ppl_ratio_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_ratio.second ** 2 )
- printf("Mean PPL(Q)/PPL(base) : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc);
+ LOG("Mean PPL(Q)/PPL(base) : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc);
 
  const double ppl_cov = ppl_val * ppl_base_val * log_ppl_cov;
  const double ppl_diff_val = ppl_val - ppl_base_val;
  const double ppl_diff_unc = sqrt(ppl_unc*ppl_unc + ppl_base_unc*ppl_base_unc - 2.0*ppl_cov);
- printf("Mean PPL(Q)-PPL(base) : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc);
+ LOG("Mean PPL(Q)-PPL(base) : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc);
 
- printf("\n");
+ LOG("\n");
 
- printf("====== KL divergence statistics ======\n");
+ LOG("====== KL divergence statistics ======\n");
  auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
- printf("Mean KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second);
+ LOG("Mean KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second);
  auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
  : kld_values[kld_values.size()/2];
 
@@ -1915,67 +1890,68 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
  return (1 - p)*values[ip] + p*values[std::min(ip+1, values.size()-1)];
  };
 
- printf("Maximum KLD: %10.6f\n", kld_values.back());
- printf("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
- printf("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
- printf("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
- printf("Median KLD: %10.6f\n", kld_median);
- printf("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
- printf(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));
- printf(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f));
- printf("Minimum KLD: %10.6f\n", kld_values.front());
+ LOG("Maximum KLD: %10.6f\n", kld_values.back());
+ LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
+ LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
+ LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
+ LOG("Median KLD: %10.6f\n", kld_median);
+ LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
+ LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));
+ LOG(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f));
+ LOG("Minimum KLD: %10.6f\n", kld_values.front());
 
- printf("\n");
+ LOG("\n");
 
- printf("====== Token probability statistics ======\n");
+ LOG("====== Token probability statistics ======\n");
 
  auto p_diff = mean_and_uncertainty(kld.sum_p_diff, kld.sum_p_diff2, kld.count);
- printf("Mean Δp: %6.3lf ± %5.3lf %%\n", 100.0*p_diff.first, 100.0*p_diff.second);
+ LOG("Mean Δp: %6.3lf ± %5.3lf %%\n", 100.0*p_diff.first, 100.0*p_diff.second);
 
  auto p_diff_median = p_diff_values.size()%2 == 0 ? 0.5f*(p_diff_values[p_diff_values.size()/2] + p_diff_values[p_diff_values.size()/2-1])
  : p_diff_values[p_diff_values.size()/2];
 
- printf("Maximum Δp: %6.3lf%%\n", 100.0*p_diff_values.back());
- printf("99.9%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f));
- printf("99.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f));
- printf("95.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f));
- printf("90.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f));
- printf("75.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f));
- printf("Median Δp: %6.3lf%%\n", 100.0*p_diff_median);
- printf("25.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f));
- printf("10.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f));
- printf(" 5.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f));
- printf(" 1.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f));
- printf(" 0.1%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f));
- printf("Minimum Δp: %6.3lf%%\n", 100.0*p_diff_values.front());
+ LOG("Maximum Δp: %6.3lf%%\n", 100.0*p_diff_values.back());
+ LOG("99.9%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f));
+ LOG("99.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f));
+ LOG("95.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f));
+ LOG("90.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f));
+ LOG("75.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f));
+ LOG("Median Δp: %6.3lf%%\n", 100.0*p_diff_median);
+ LOG("25.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f));
+ LOG("10.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f));
+ LOG(" 5.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f));
+ LOG(" 1.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f));
+ LOG(" 0.1%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f));
+ LOG("Minimum Δp: %6.3lf%%\n", 100.0*p_diff_values.front());
 
  auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
- // printf("MSE Δp : %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second);
+ // LOG("MSE Δp : %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second);
 
  const double p_diff_rms_val = sqrt(p_diff_mse.first);
  const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
- printf("RMS Δp : %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
+ LOG("RMS Δp : %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
 
  const double same_top_p = 1.0*kld.n_same_top/kld.count;
- printf("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1)));
-
+ LOG("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1)));
  }
 
  int main(int argc, char ** argv) {
- gpt_params params;
+ common_params params;
 
  params.n_ctx = 512;
  params.logits_all = true;
+ params.escape = false;
 
- if (!gpt_params_parse(argc, argv, params)) {
- gpt_params_print_usage(argc, argv, params);
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
  return 1;
  }
 
+ common_init();
+
  const int32_t n_ctx = params.n_ctx;
 
  if (n_ctx <= 0) {
- fprintf(stderr, "%s: perplexity tool requires '--ctx-size' > 0\n", __func__);
+ LOG_ERR("%s: perplexity tool requires '--ctx-size' > 0\n", __func__);
  return 1;
  }
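The main() hunks above carry the argument-handling changes of this release: gpt_params becomes common_params, parsing goes through common_params_parse() with an example id (via the new common/arg.h added in this package version), and common_init() is called after parsing; the next hunk below also drops the hand-rolled print_build_info() and manual seed setup. A hedged start-up sketch, assuming these entry points keep the signatures visible in the hunks:

    #include "arg.h"
    #include "common.h"
    #include "log.h"

    int main(int argc, char ** argv) {
        common_params params;
        params.n_ctx = 512; // example default, mirroring the perplexity tool

        // the example id selects which options the parser exposes for this tool
        if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
            return 1;
        }

        common_init(); // common logging / build-info setup

        LOG_INF("%s: n_ctx = %d\n", __func__, params.n_ctx);
        return 0;
    }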
 
@@ -2000,45 +1976,35 @@ int main(int argc, char ** argv) {
  }
 
  if (params.ppl_stride > 0) {
- fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
+ LOG_INF("Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
  params.n_ctx, params.n_ctx + params.ppl_stride/2);
  params.n_ctx += params.ppl_stride/2;
  }
 
- print_build_info();
-
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
-
- fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
- std::mt19937 rng(params.seed);
-
  llama_backend_init();
  llama_numa_init(params.numa);
 
- llama_model * model;
- llama_context * ctx;
-
  // load the model and apply lora adapter, if any
- std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ common_init_result llama_init = common_init_from_params(params);
+
+ llama_model * model = llama_init.model;
+ llama_context * ctx = llama_init.context;
  if (model == NULL) {
- fprintf(stderr, "%s: error: unable to load model\n", __func__);
+ LOG_ERR("%s: unable to load model\n", __func__);
  return 1;
  }
 
  const int n_ctx_train = llama_n_ctx_train(model);
 
  if (params.n_ctx > n_ctx_train) {
- fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
+ LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
  __func__, n_ctx_train, params.n_ctx);
  }
 
  // print system information
  {
- fprintf(stderr, "\n");
- fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str());
+ LOG_INF("\n");
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
  }
 
  struct results_perplexity results;
@@ -2054,8 +2020,8 @@ int main(int argc, char ** argv) {
  results = perplexity(ctx, params, n_ctx);
  }
 
- llama_print_timings(ctx);
- write_logfile(ctx, params, model, results);
+ LOG("\n");
+ llama_perf_context_print(ctx);
 
  llama_free(ctx);
  llama_free_model(model);
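Model setup and teardown also change in the hunks above: llama_init_from_gpt_params() becomes common_init_from_params(), which returns the model and context together, and llama_print_timings()/write_logfile() give way to llama_perf_context_print(). A hedged end-to-end sketch, assuming common_init_result exposes the raw .model/.context pointers as in this diff (run_with_model_sketch is illustrative):

    #include "common.h"
    #include "log.h"
    #include "llama.h"

    static int run_with_model_sketch(common_params & params) {
        llama_backend_init();
        llama_numa_init(params.numa);

        common_init_result llama_init = common_init_from_params(params);
        llama_model   * model = llama_init.model;
        llama_context * ctx   = llama_init.context;
        if (model == NULL) {
            LOG_ERR("%s: unable to load model\n", __func__);
            return 1;
        }

        // ... evaluate prompts / score tasks here ...

        llama_perf_context_print(ctx); // replaces llama_print_timings()
        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }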