@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/examples/speculative/speculative.cpp

@@ -1,11 +1,16 @@
+ #include "arg.h"
  #include "common.h"
+ #include "sampling.h"
+ #include "log.h"
  #include "llama.h"

- #include <cmath>
+ #include <algorithm>
  #include <cstdio>
+ #include <cstring>
+ #include <random>
+ #include <set>
  #include <string>
  #include <vector>
- #include <set>

  #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
  #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -21,19 +26,28 @@ struct seq_draft {
  std::vector<llama_token> tokens;
  std::vector<std::vector<llama_token_data>> dists;

- struct llama_sampling_context * ctx_sampling;
+ struct common_sampler * smpl = nullptr;
  };

  int main(int argc, char ** argv) {
- gpt_params params;
+ common_params params;
+
+ // needed to get candidate probs even for temp <= 0.0
+ params.sparams.n_probs = 128;

- if (!gpt_params_parse(argc, argv, params)) {
- gpt_params_print_usage(argc, argv, params);
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
  return 1;
  }

+ if (params.n_predict < -1) {
+ LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+ return 1;
+ }
+
+ common_init();
+
  if (params.model_draft.empty()) {
- fprintf(stderr, "%s: error: --model-draft is required\n", __func__);
+ LOG_ERR("%s: --model-draft is required\n", __func__);
  return 1;
  }

@@ -43,18 +57,9 @@ int main(int argc, char ** argv) {
  // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
  const float p_split = params.p_split;

- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
- std::default_random_engine rng(params.seed);
+ std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
  std::uniform_real_distribution<> u_dist;

- #ifndef LOG_DISABLE_LOGS
- log_set_target(log_filename_generator("speculative", "log"));
- LOG_TEE("Log start\n");
- log_dump_cmdline(argc, argv);
- #endif // LOG_DISABLE_LOGS
-
  // init llama.cpp
  llama_backend_init();
  llama_numa_init(params.numa);
@@ -66,26 +71,31 @@ int main(int argc, char ** argv) {
  llama_context * ctx_dft = NULL;

  // load the target model
- std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
+ common_init_result llama_init_tgt = common_init_from_params(params);
+ model_tgt = llama_init_tgt.model;
+ ctx_tgt = llama_init_tgt.context;

  // load the draft model
  params.model = params.model_draft;
  params.n_gpu_layers = params.n_gpu_layers_draft;
- if (params.n_threads_draft > 0) {
- params.n_threads = params.n_threads_draft;
+ if (params.draft_cpuparams.n_threads > 0) {
+ params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
  }
- params.n_threads_batch = params.n_threads_batch_draft;
- std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
+
+ params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
+ common_init_result llama_init_dft = common_init_from_params(params);
+ model_dft = llama_init_dft.model;
+ ctx_dft = llama_init_dft.context;

  const bool vocab_type_tgt = llama_vocab_type(model_tgt);
- LOG("vocab_type tgt: %d\n", vocab_type_tgt);
+ LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);

  const bool vocab_type_dft = llama_vocab_type(model_dft);
- LOG("vocab_type dft: %d\n", vocab_type_dft);
+ LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);

  if (vocab_type_tgt != vocab_type_dft) {
- fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__);
- fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
+ LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
+ LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
  return 1;
  }

@@ -95,7 +105,7 @@ int main(int argc, char ** argv) {
  llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
  llama_token_eos(model_tgt) != llama_token_eos(model_dft)
  ) {
- fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
+ LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
  return 1;
  }

@@ -107,8 +117,8 @@ int main(int argc, char ** argv) {
  : n_vocab_dft - n_vocab_tgt;

  if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
- fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
- fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+ LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
+ LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
  n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
  return 1;
  }
@@ -117,10 +127,10 @@ int main(int argc, char ** argv) {
  const char * token_text_tgt = llama_token_get_text(model_tgt, i);
  const char * token_text_dft = llama_token_get_text(model_dft, i);
  if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
- fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
- fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
- llama_token_to_piece(ctx_tgt, i).c_str(),
- llama_token_to_piece(ctx_dft, i).c_str());
+ LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
+ LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
+ common_token_to_piece(ctx_tgt, i).c_str(),
+ common_token_to_piece(ctx_dft, i).c_str());
  return 1;
  }
  }
@@ -129,32 +139,30 @@ int main(int argc, char ** argv) {

  // Tokenize the prompt
  std::vector<llama_token> inp;
- inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true);
+ inp = common_tokenize(ctx_tgt, params.prompt, true, true);

  const int max_context_size = llama_n_ctx(ctx_tgt);
  const int max_tokens_list_size = max_context_size - 4;

  if ((int) inp.size() > max_tokens_list_size) {
- fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
+ LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
  return 1;
  }

- fprintf(stderr, "\n\n");
+ LOG("\n\n");

  for (auto id : inp) {
- fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str());
+ LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
  }

- fflush(stderr);
-
  const int n_input = inp.size();

  const auto t_enc_start = ggml_time_us();

  // eval the prompt with both models
- llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
- llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
- llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));
+ llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
+ llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
+ llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));

  const auto t_enc_end = ggml_time_us();

@@ -174,23 +182,19 @@ int main(int argc, char ** argv) {
  // used to determine end of generation
  bool has_eos = false;

- // target model sampling context
- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+ // target model sampling context (reuse the llama_context's sampling instance)
+ struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);

  // draft sequence data
  std::vector<seq_draft> drafts(n_seq_dft);

- params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
- if (params.sparams.temp == 0) {
- params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
- }
-
  for (int s = 0; s < n_seq_dft; ++s) {
- drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
+ // allocate llama_sampler for each draft sequence
+ drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
  }

- llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
- llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
+ llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
+ llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);

  const auto t_dec_start = ggml_time_us();

@@ -210,7 +214,7 @@ int main(int argc, char ** argv) {
  active_seqs.insert(s);
  const auto & tokens = drafts[s].tokens;

- LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
+ LOG_DBG("draft %d: %s\n", s, string_from(ctx_dft, tokens).c_str());
  }

  int i_dft = 0;
@@ -228,12 +232,12 @@ int main(int argc, char ** argv) {
  bool accept = false;
  if (params.sparams.temp > 0) {
  // stochastic verification
+ common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);

- llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
- llama_sample_softmax(ctx_tgt, &dist_tgt);
- float p_tgt = 0, p_dft = 0;
+ auto & dist_tgt = *common_sampler_get_candidates(smpl);

- // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
+ float p_tgt = 0.0f;
+ float p_dft = 0.0f;

  while (active_seqs.size() > 0) {
  // randomly select a sequence to verify from active sequences
@@ -252,39 +256,43 @@ int main(int argc, char ** argv) {
  }
  continue;
  }
- LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
+
+ LOG_DBG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
  float r = u_dist(rng);
- llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
+ llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
+
+ //GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+

  // acquire the token probabilities assigned by the draft and target models
  for (size_t i = 0; i < dist_tgt.size; i++) {
  if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
  p_tgt = dist_tgt.data[i].p;
+ break;
  }
+ }
+ for (size_t i = 0; i < dist_dft.size; i++) {
  if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
  p_dft = dist_dft.data[i].p;
- }
- if (p_tgt && p_dft) {
  break;
  }
  }
- LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
+ LOG_DBG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
  if (r <= p_tgt / p_dft) {
  s_keep = s;
  accept = true;
  token_id = drafts[s].tokens[i_dft];
- token_str = llama_token_to_piece(ctx_tgt, token_id);
- llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+ token_str = common_token_to_piece(ctx_tgt, token_id);
+ common_sampler_accept(smpl, token_id, true);

- LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
+ LOG_DBG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
  break;
  } else {
- LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
+ LOG_DBG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], common_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
  drafts[s].active = false;

  // calculate residual probability
  GGML_ASSERT(dist_tgt.sorted);
  GGML_ASSERT(dist_dft.sorted);
- float sum_probs = 0.0f;

  // sort dist by id
  std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
@@ -294,10 +302,18 @@ int main(int argc, char ** argv) {
  return a.id < b.id;
  });

+ float sum_probs = 0.0f;
+
  for (size_t i = 0; i < dist_tgt.size; i++) {
- dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+ if (i < dist_dft.size) {
+ dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+ } else {
+ dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p);
+ }
+
  sum_probs += dist_tgt.data[i].p;
  }
+
  for (size_t i = 0; i < dist_tgt.size; i++) {
  dist_tgt.data[i].p /= sum_probs;
  }
@@ -326,24 +342,30 @@ int main(int argc, char ** argv) {
  if (!accept) {
  // all drafted tokens were rejected
  // sample from the target model
- LOG("all drafted tokens were rejected, sampling from residual distribution\n");
- token_id = llama_sample_token(ctx_tgt, &dist_tgt);
- llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
- token_str = llama_token_to_piece(ctx_tgt, token_id);
- }
+ LOG_DBG("all drafted tokens were rejected, sampling from residual distribution\n");
+ std::vector<float> probs(dist_tgt.size);
+ for (size_t i = 0; i < dist_tgt.size; ++i) {
+ probs[i] = dist_tgt.data[i].p;
+ }
+
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+ const int idx = dist(rng);

+ token_id = dist_tgt.data[idx].id;
+ common_sampler_accept(smpl, token_id, true);
+ token_str = common_token_to_piece(ctx_tgt, token_id);
+ }
  } else {
  // greedy verification

  // sample from the target model
- LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
- token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
-
- llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+ LOG_DBG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
+ token_id = common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);

- //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
+ common_sampler_accept(smpl, token_id, true);

- token_str = llama_token_to_piece(ctx_tgt, token_id);
+ token_str = common_token_to_piece(ctx_tgt, token_id);

  for (int s = 0; s < n_seq_dft; ++s) {
  if (!drafts[s].active) {
@@ -351,7 +373,7 @@ int main(int argc, char ** argv) {
  }

  if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
- LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
+ LOG_DBG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());

  s_keep = s;
  accept = true;
@@ -373,26 +395,24 @@ int main(int argc, char ** argv) {
  ++i_dft;
  if (params.use_color) {
  // Color token according to its origin sequence
- printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
+ LOG("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
  } else {
- printf("%s", token_str.c_str());
+ LOG("%s", token_str.c_str());
  }
- fflush(stdout);
  continue;
  } else {
- printf("%s", token_str.c_str());
- fflush(stdout);
+ LOG("%s", token_str.c_str());
  break;
  }
  }
  }

  {
- LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
+ LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());

  // TODO: simplify
  {
- LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
+ LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);

  llama_kv_cache_seq_keep(ctx_dft, s_keep);
  llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
@@ -415,21 +435,24 @@ int main(int argc, char ** argv) {
  drafts[0].dists.push_back(std::vector<llama_token_data>());
  drafts[0].i_batch_tgt.push_back(0);

- llama_batch_clear(batch_dft);
- llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
+ common_batch_clear(batch_dft);
+ common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);

  llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
- // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
+ // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
  llama_decode(ctx_dft, batch_dft);

  ++n_past_dft;
  }

- if (n_predict > params.n_predict || has_eos) {
+ if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
  break;
  }

- llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
+ if (drafts[0].smpl) {
+ common_sampler_free(drafts[0].smpl);
+ }
+ drafts[0].smpl = common_sampler_clone(smpl);

  int n_seq_cur = 1;
  int n_past_cur = n_past_dft;
@@ -442,8 +465,8 @@ int main(int argc, char ** argv) {
  drafts[0].drafting = true;
  drafts[0].i_batch_dft = 0;

- llama_batch_clear(batch_tgt);
- llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
+ common_batch_clear(batch_tgt);
+ common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);

  // sample n_draft tokens from the draft model using tree-based sampling
  for (int i = 0; i < n_draft; ++i) {
@@ -458,21 +481,21 @@ int main(int argc, char ** argv) {
  continue;
  }

- llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
+ common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);

- const auto & cur_p = drafts[s].ctx_sampling->cur;
+ const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);

- for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
- LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
- k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
+ for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
+ LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
+ k, s, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
  }

  std::vector<int> sa(1, s);

  // attempt to split the branch if the probability is high enough
  for (int f = 1; f < 8; ++f) {
- if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
- LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
+ if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
+ LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);

  llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
  llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
@@ -498,7 +521,10 @@ int main(int argc, char ** argv) {
  drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
  drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;

- llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
+ if (drafts[n_seq_cur].smpl) {
+ common_sampler_free(drafts[n_seq_cur].smpl);
+ }
+ drafts[n_seq_cur].smpl = common_sampler_clone(drafts[s].smpl);

  sa.push_back(n_seq_cur);

@@ -510,25 +536,25 @@ int main(int argc, char ** argv) {

  // add drafted token for each sequence
  for (int is = 0; is < (int) sa.size(); ++is) {
- const llama_token id = cur_p[is].id;
+ const llama_token id = cur_p->data[is].id;

  const int s = sa[is];

- llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
+ common_sampler_accept(drafts[s].smpl, id, true);

  drafts[s].tokens.push_back(id);
  // save cur_p.data into drafts[s].dists
- drafts[s].dists.push_back(cur_p);
+ drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});

  // add unique drafted tokens to the target batch
  drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);

- llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
+ common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);

  // add the token to the batch for batched decoding with the draft model
  drafts[s].i_batch_dft = batch_dft.n_tokens;

- llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
+ common_batch_add(batch_dft, id, n_past_cur, { s }, true);

  if (batch_tgt.n_tokens > n_draft) {
  drafts[s].drafting = false;
@@ -558,7 +584,7 @@ int main(int argc, char ** argv) {
  llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
  }

- // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
+ // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
  llama_decode(ctx_tgt, batch_tgt);
  ++n_past_tgt;
  }
@@ -576,27 +602,30 @@ int main(int argc, char ** argv) {

  auto t_dec_end = ggml_time_us();

- LOG_TEE("\n\n");
+ LOG("\n\n");

- LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
- LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
+ LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+ LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));

- LOG_TEE("\n");
- LOG_TEE("n_draft = %d\n", n_draft);
- LOG_TEE("n_predict = %d\n", n_predict);
- LOG_TEE("n_drafted = %d\n", n_drafted);
- LOG_TEE("n_accept = %d\n", n_accept);
- LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
+ LOG_INF("\n");
+ LOG_INF("n_draft = %d\n", n_draft);
+ LOG_INF("n_predict = %d\n", n_predict);
+ LOG_INF("n_drafted = %d\n", n_drafted);
+ LOG_INF("n_accept = %d\n", n_accept);
+ LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);

- LOG_TEE("\ndraft:\n");
- llama_print_timings(ctx_dft);
+ LOG_INF("\n");
+ LOG_INF("draft:\n\n");
+ // TODO: print sampling/grammar timings for all drafts
+ llama_perf_context_print(ctx_dft);

- LOG_TEE("\ntarget:\n");
- llama_print_timings(ctx_tgt);
+ LOG_INF("\n");
+ LOG_INF("target:\n\n");
+ common_perf_print(ctx_tgt, smpl);

- llama_sampling_free(ctx_sampling);
+ common_sampler_free(smpl);
  for (int s = 0; s < n_seq_dft; ++s) {
- llama_sampling_free(drafts[s].ctx_sampling);
+ common_sampler_free(drafts[s].smpl);
  }

  llama_batch_free(batch_dft);
@@ -609,7 +638,7 @@ int main(int argc, char ** argv) {

  llama_backend_free();

- fprintf(stderr, "\n\n");
+ LOG("\n\n");

  return 0;
  }
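For readers following the stochastic-verification changes above: when every drafted token is rejected, the updated example builds a residual distribution (target probability minus draft probability, clamped at zero and renormalized) and draws a token from it with std::discrete_distribution. Below is a minimal standalone sketch of that residual step only; the struct and function names are illustrative and not part of the llama.cpp or @fugood/llama.node API, and both candidate lists are assumed sorted by token id, as the example code ensures before this point.

// sketch: residual (rejection) sampling step from speculative decoding
#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>

struct token_prob {
    int   id; // token id
    float p;  // probability assigned by a model
};

// Subtract the draft distribution from the target distribution, clamp at zero,
// renormalize, and draw one token from the result.
// Assumes at least one residual probability is positive.
static int sample_residual(std::vector<token_prob> tgt,
                           const std::vector<token_prob> & dft,
                           std::mt19937 & rng) {
    float sum = 0.0f;
    for (size_t i = 0; i < tgt.size(); ++i) {
        const float p_dft = i < dft.size() ? dft[i].p : 0.0f;
        tgt[i].p = std::max(0.0f, tgt[i].p - p_dft);
        sum += tgt[i].p;
    }
    std::vector<float> probs(tgt.size());
    for (size_t i = 0; i < tgt.size(); ++i) {
        probs[i] = tgt[i].p / sum; // renormalize
    }
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    return tgt[dist(rng)].id;    // index drawn proportionally to residual mass
}

int main() {
    std::mt19937 rng(42);
    // toy distributions over a 4-token vocabulary, both sorted by token id
    std::vector<token_prob> tgt = {{0, 0.40f}, {1, 0.30f}, {2, 0.20f}, {3, 0.10f}};
    std::vector<token_prob> dft = {{0, 0.70f}, {1, 0.10f}, {2, 0.10f}, {3, 0.10f}};
    std::printf("sampled token id: %d\n", sample_residual(tgt, dft, rng));
    return 0;
}

This mirrors the diff's behavior of favoring tokens the target model rates higher than the draft model did, which keeps the overall output distribution equal to sampling from the target model alone.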
package/src/llama.cpp/examples/sycl/run-llama2.sh

@@ -4,33 +4,24 @@
  # Copyright (C) 2024 Intel Corporation
  # SPDX-License-Identifier: MIT

- INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
  source /opt/intel/oneapi/setvars.sh

- if [ $# -gt 0 ]; then
- GGML_SYCL_DEVICE=$1
- GGML_SYCL_SINGLE_GPU=1
- else
- GGML_SYCL_DEVICE=0
- GGML_SYCL_SINGLE_GPU=0
- fi
-
  #export GGML_SYCL_DEBUG=1

-
  #ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer.

- if [ $GGML_SYCL_SINGLE_GPU -eq 1 ]; then
+ INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
+ MODEL_FILE=models/llama-2-7b.Q4_0.gguf
+ NGL=33
+ CONEXT=8192
+
+ if [ $# -gt 0 ]; then
+ GGML_SYCL_DEVICE=$1
  echo "use $GGML_SYCL_DEVICE as main GPU"
  #use signle GPU only
- ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0 -mg $GGML_SYCL_DEVICE -sm none
+ ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONEXT} -mg $GGML_SYCL_DEVICE -sm none
+
  else
  #use multiple GPUs with same max compute units
- ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0
+ ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONEXT}
  fi
-
- #use main GPU only
- #ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0 -mg $GGML_SYCL_DEVICE -sm none
-
- #use multiple GPUs with same max compute units
- #ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0
package/src/llama.cpp/examples/sycl/win-run-llama2.bat

@@ -6,4 +6,4 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
  @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force


- .\build\bin\main.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 33 -s 0
+ .\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 33 -s 0