@fugood/llama.node 0.3.2 → 0.3.4

This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (286)
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0

package/src/llama.cpp/examples/speculative/speculative.cpp

@@ -12,7 +12,7 @@
  #include <string>
  #include <vector>

- #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
+ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
  #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

  struct seq_draft {
@@ -26,22 +26,27 @@ struct seq_draft {
  std::vector<llama_token> tokens;
  std::vector<std::vector<llama_token_data>> dists;

- struct gpt_sampler * smpl = nullptr;
+ struct common_sampler * smpl = nullptr;
  };

  int main(int argc, char ** argv) {
- gpt_params params;
+ common_params params;

  // needed to get candidate probs even for temp <= 0.0
- params.sparams.n_probs = 128;
+ params.sampling.n_probs = 128;

- if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
  return 1;
  }

- gpt_init();
+ if (params.n_predict < -1) {
+ LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+ return 1;
+ }

- if (params.model_draft.empty()) {
+ common_init();
+
+ if (params.speculative.model.empty()) {
  LOG_ERR("%s: --model-draft is required\n", __func__);
  return 1;
  }
@@ -50,9 +55,9 @@ int main(int argc, char ** argv) {
  const int n_seq_dft = params.n_parallel;

  // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
- const float p_split = params.p_split;
+ const float p_draft_split = params.speculative.p_split;

- std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
+ std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed);
  std::uniform_real_distribution<> u_dist;

  // init llama.cpp
@@ -66,19 +71,20 @@ int main(int argc, char ** argv) {
  llama_context * ctx_dft = NULL;

  // load the target model
- llama_init_result llama_init_tgt = llama_init_from_gpt_params(params);
+ common_init_result llama_init_tgt = common_init_from_params(params);
  model_tgt = llama_init_tgt.model;
  ctx_tgt = llama_init_tgt.context;

  // load the draft model
- params.model = params.model_draft;
- params.n_gpu_layers = params.n_gpu_layers_draft;
- if (params.draft_cpuparams.n_threads > 0) {
- params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
+ params.devices = params.speculative.devices;
+ params.model = params.speculative.model;
+ params.n_gpu_layers = params.speculative.n_gpu_layers;
+ if (params.speculative.cpuparams.n_threads > 0) {
+ params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
  }

- params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
- llama_init_result llama_init_dft = llama_init_from_gpt_params(params);
+ params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+ common_init_result llama_init_dft = common_init_from_params(params);
  model_dft = llama_init_dft.model;
  ctx_dft = llama_init_dft.context;

@@ -124,8 +130,8 @@ int main(int argc, char ** argv) {
  if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
  LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
  LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
- llama_token_to_piece(ctx_tgt, i).c_str(),
- llama_token_to_piece(ctx_dft, i).c_str());
+ common_token_to_piece(ctx_tgt, i).c_str(),
+ common_token_to_piece(ctx_dft, i).c_str());
  return 1;
  }
  }
@@ -134,7 +140,7 @@ int main(int argc, char ** argv) {

  // Tokenize the prompt
  std::vector<llama_token> inp;
- inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true);
+ inp = common_tokenize(ctx_tgt, params.prompt, true, true);

  const int max_context_size = llama_n_ctx(ctx_tgt);
  const int max_tokens_list_size = max_context_size - 4;
@@ -147,7 +153,7 @@ int main(int argc, char ** argv) {
  LOG("\n\n");

  for (auto id : inp) {
- LOG("%s", llama_token_to_piece(ctx_tgt, id).c_str());
+ LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
  }

  const int n_input = inp.size();
@@ -155,9 +161,9 @@ int main(int argc, char ** argv) {
  const auto t_enc_start = ggml_time_us();

  // eval the prompt with both models
- llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
- llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
- llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));
+ llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
+ llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
+ llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));

  const auto t_enc_end = ggml_time_us();

@@ -165,7 +171,7 @@ int main(int argc, char ** argv) {
  //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));

  // how many tokens to draft each time
- int n_draft = params.n_draft;
+ int n_draft = params.speculative.n_max;

  int n_predict = 0;
  int n_drafted = 0;
@@ -178,20 +184,18 @@ int main(int argc, char ** argv) {
  bool has_eos = false;

  // target model sampling context (reuse the llama_context's sampling instance)
- struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
-
- struct llama_sampler * softmax = llama_sampler_init_softmax();
+ struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);

  // draft sequence data
  std::vector<seq_draft> drafts(n_seq_dft);

  for (int s = 0; s < n_seq_dft; ++s) {
- // allocate gpt_sampler for each draft sequence
- drafts[s].smpl = gpt_sampler_init(model_dft, params.sparams);
+ // allocate llama_sampler for each draft sequence
+ drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
  }

- llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
- llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
+ llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
+ llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);

  const auto t_dec_start = ggml_time_us();

@@ -227,11 +231,11 @@ int main(int argc, char ** argv) {
  // for stochastic sampling, attempt to match the token with the drafted tokens
  {
  bool accept = false;
- if (params.sparams.temp > 0) {
+ if (params.sampling.temp > 0) {
  // stochastic verification
- gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
+ common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);

- auto & dist_tgt = *gpt_sampler_get_candidates(smpl);
+ auto & dist_tgt = *common_sampler_get_candidates(smpl);

  float p_tgt = 0.0f;
  float p_dft = 0.0f;
@@ -264,11 +268,12 @@ int main(int argc, char ** argv) {
  for (size_t i = 0; i < dist_tgt.size; i++) {
  if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
  p_tgt = dist_tgt.data[i].p;
+ break;
  }
+ }
+ for (size_t i = 0; i < dist_dft.size; i++) {
  if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
  p_dft = dist_dft.data[i].p;
- }
- if (p_tgt && p_dft) {
  break;
  }
  }
@@ -277,13 +282,13 @@ int main(int argc, char ** argv) {
  s_keep = s;
  accept = true;
  token_id = drafts[s].tokens[i_dft];
- token_str = llama_token_to_piece(ctx_tgt, token_id);
- gpt_sampler_accept(smpl, token_id, true);
+ token_str = common_token_to_piece(ctx_tgt, token_id);
+ common_sampler_accept(smpl, token_id, true);

  LOG_DBG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
  break;
  } else {
- LOG_DBG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
+ LOG_DBG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], common_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
  drafts[s].active = false;

  // calculate residual probability
@@ -349,19 +354,19 @@ int main(int argc, char ** argv) {
  const int idx = dist(rng);

  token_id = dist_tgt.data[idx].id;
- gpt_sampler_accept(smpl, token_id, true);
- token_str = llama_token_to_piece(ctx_tgt, token_id);
+ common_sampler_accept(smpl, token_id, true);
+ token_str = common_token_to_piece(ctx_tgt, token_id);
  }
  } else {
  // greedy verification

  // sample from the target model
  LOG_DBG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
- token_id = gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
+ token_id = common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);

- gpt_sampler_accept(smpl, token_id, true);
+ common_sampler_accept(smpl, token_id, true);

- token_str = llama_token_to_piece(ctx_tgt, token_id);
+ token_str = common_token_to_piece(ctx_tgt, token_id);

  for (int s = 0; s < n_seq_dft; ++s) {
  if (!drafts[s].active) {
@@ -431,8 +436,8 @@ int main(int argc, char ** argv) {
  drafts[0].dists.push_back(std::vector<llama_token_data>());
  drafts[0].i_batch_tgt.push_back(0);

- llama_batch_clear(batch_dft);
- llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
+ common_batch_clear(batch_dft);
+ common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);

  llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
  // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
@@ -441,14 +446,14 @@ int main(int argc, char ** argv) {
  ++n_past_dft;
  }

- if (n_predict > params.n_predict || has_eos) {
+ if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
  break;
  }

  if (drafts[0].smpl) {
- gpt_sampler_free(drafts[0].smpl);
+ common_sampler_free(drafts[0].smpl);
  }
- drafts[0].smpl = gpt_sampler_clone(smpl);
+ drafts[0].smpl = common_sampler_clone(smpl);

  int n_seq_cur = 1;
  int n_past_cur = n_past_dft;
@@ -461,8 +466,8 @@ int main(int argc, char ** argv) {
  drafts[0].drafting = true;
  drafts[0].i_batch_dft = 0;

- llama_batch_clear(batch_tgt);
- llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
+ common_batch_clear(batch_tgt);
+ common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);

  // sample n_draft tokens from the draft model using tree-based sampling
  for (int i = 0; i < n_draft; ++i) {
@@ -477,20 +482,20 @@ int main(int argc, char ** argv) {
  continue;
  }

- gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
+ common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);

- const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl);
+ const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);

  for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
  LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
- k, s, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
+ k, s, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
  }

  std::vector<int> sa(1, s);

  // attempt to split the branch if the probability is high enough
  for (int f = 1; f < 8; ++f) {
- if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
+ if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
  LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);

  llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
@@ -518,9 +523,9 @@ int main(int argc, char ** argv) {
  drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;

  if (drafts[n_seq_cur].smpl) {
- gpt_sampler_free(drafts[n_seq_cur].smpl);
+ common_sampler_free(drafts[n_seq_cur].smpl);
  }
- drafts[n_seq_cur].smpl = gpt_sampler_clone(drafts[s].smpl);
+ drafts[n_seq_cur].smpl = common_sampler_clone(drafts[s].smpl);

  sa.push_back(n_seq_cur);

@@ -536,7 +541,7 @@ int main(int argc, char ** argv) {

  const int s = sa[is];

- gpt_sampler_accept(drafts[s].smpl, id, true);
+ common_sampler_accept(drafts[s].smpl, id, true);

  drafts[s].tokens.push_back(id);
  // save cur_p.data into drafts[s].dists
@@ -545,12 +550,12 @@ int main(int argc, char ** argv) {
  // add unique drafted tokens to the target batch
  drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);

- llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
+ common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);

  // add the token to the batch for batched decoding with the draft model
  drafts[s].i_batch_dft = batch_dft.n_tokens;

- llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
+ common_batch_add(batch_dft, id, n_past_cur, { s }, true);

  if (batch_tgt.n_tokens > n_draft) {
  drafts[s].drafting = false;
@@ -617,14 +622,13 @@ int main(int argc, char ** argv) {

  LOG_INF("\n");
  LOG_INF("target:\n\n");
- gpt_perf_print(ctx_tgt, smpl);
+ common_perf_print(ctx_tgt, smpl);

- gpt_sampler_free(smpl);
+ common_sampler_free(smpl);
  for (int s = 0; s < n_seq_dft; ++s) {
- gpt_sampler_free(drafts[s].smpl);
+ common_sampler_free(drafts[s].smpl);
  }

- llama_sampler_free(softmax);
  llama_batch_free(batch_dft);

  llama_free(ctx_tgt);
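
The bulk of the speculative.cpp changes above track the vendored llama.cpp rename of the `gpt_*` common helpers to `common_*` (`gpt_params` -> `common_params`, `gpt_sampler_*` -> `common_sampler_*`, `llama_token_to_piece`/`::llama_tokenize` -> `common_token_to_piece`/`common_tokenize`, `llama_batch_clear/add` -> `common_batch_clear/add`), plus the two-argument `llama_batch_get_one`. For downstream code that links against this vendored `common` library, here is a minimal, hedged sketch that collects the renamed calls in one place; it is not part of this package, and it assumes the headers and helper signatures visible in the diff above.

```cpp
// Minimal sketch (not from this package) of the renamed common_* helpers, assuming the
// vendored llama.cpp "common" library and the signatures shown in the diff above.
#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "llama.h"

#include <cstdio>
#include <vector>

int main(int argc, char ** argv) {
    common_params params;                                                   // was: gpt_params
    // LLAMA_EXAMPLE_COMMON is assumed here; the diff itself uses LLAMA_EXAMPLE_SPECULATIVE
    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {   // was: gpt_params_parse
        return 1;
    }
    common_init();                                                          // was: gpt_init

    llama_backend_init();
    common_init_result init = common_init_from_params(params);             // was: llama_init_from_gpt_params
    llama_model   * model = init.model;
    llama_context * ctx   = init.context;

    // was: ::llama_tokenize; llama_batch_get_one no longer takes position/seq-id arguments
    std::vector<llama_token> inp = common_tokenize(ctx, params.prompt, true, true);
    llama_decode(ctx, llama_batch_get_one(inp.data(), (int32_t) inp.size()));

    common_sampler * smpl = common_sampler_init(model, params.sampling);    // was: gpt_sampler_init(model, params.sparams)
    const llama_token id  = common_sampler_sample(smpl, ctx, -1);           // was: gpt_sampler_sample
    common_sampler_accept(smpl, id, true);                                  // was: gpt_sampler_accept
    std::printf("%s\n", common_token_to_piece(ctx, id).c_str());            // was: llama_token_to_piece

    common_sampler_free(smpl);                                              // was: gpt_sampler_free
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```

The draft-model options also moved: the per-draft fields (`model_draft`, `n_gpu_layers_draft`, `draft_cpuparams`, `p_split`) are now grouped under `params.speculative`, as the hunks above show.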

package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt (new file)

@@ -0,0 +1,5 @@
+ set(TARGET llama-speculative-simple)
+ add_executable(${TARGET} speculative-simple.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_17)

package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp (new file)

@@ -0,0 +1,265 @@
+ #include "arg.h"
+ #include "common.h"
+ #include "sampling.h"
+ #include "speculative.h"
+ #include "log.h"
+ #include "llama.h"
+
+ #include <cstdio>
+ #include <cstring>
+ #include <string>
+ #include <vector>
+
+ int main(int argc, char ** argv) {
+ common_params params;
+
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
+ return 1;
+ }
+
+ if (params.n_predict < -1) {
+ LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+ return 1;
+ }
+
+ common_init();
+
+ if (params.speculative.model.empty()) {
+ LOG_ERR("%s: --model-draft is required\n", __func__);
+ return 1;
+ }
+
+ // init llama.cpp
+ llama_backend_init();
+ llama_numa_init(params.numa);
+
+ llama_model * model_tgt = NULL;
+ llama_model * model_dft = NULL;
+
+ llama_context * ctx_tgt = NULL;
+ llama_context * ctx_dft = NULL;
+
+ // load the target model
+ common_init_result llama_init_tgt = common_init_from_params(params);
+
+ model_tgt = llama_init_tgt.model;
+ ctx_tgt = llama_init_tgt.context;
+
+ // load the draft model
+ params.devices = params.speculative.devices;
+ params.model = params.speculative.model;
+ params.n_ctx = params.speculative.n_ctx;
+ params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
+ params.n_gpu_layers = params.speculative.n_gpu_layers;
+
+ if (params.speculative.cpuparams.n_threads > 0) {
+ params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
+ }
+
+ params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+ common_init_result llama_init_dft = common_init_from_params(params);
+
+ model_dft = llama_init_dft.model;
+ ctx_dft = llama_init_dft.context;
+
+ if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
+ return 1;
+ }
+
+ // Tokenize the prompt
+ std::vector<llama_token> inp;
+ inp = common_tokenize(ctx_tgt, params.prompt, true, true);
+
+ if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) {
+ LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+
+ return 1;
+ }
+
+ if (llama_n_batch(ctx_tgt) < (uint32_t) inp.size()) {
+ LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
+
+ return 1;
+ }
+
+ LOG("\n\n");
+
+ for (auto id : inp) {
+ LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
+ }
+
+ // how many tokens to draft each time
+ int n_draft = params.speculative.n_max;
+ int n_draft_min = params.speculative.n_min;
+
+ float p_min = params.speculative.p_min;
+
+ int n_predict = 0;
+ int n_drafted = 0;
+ int n_accept = 0;
+
+ // used to determine end of generation
+ bool has_eos = false;
+
+ // ================================================
+ // everything until here is standard initialization
+ // the relevant stuff for speculative decoding starts here
+
+ const auto t_enc_start = ggml_time_us();
+
+ // target model sampling context
+ struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
+
+ // eval the prompt
+ llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
+
+ // note: keep the last token separate!
+ llama_token id_last = inp.back();
+
+ // all tokens currently in the target context
+ llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+ prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
+
+ int n_past = inp.size() - 1;
+
+ // init the speculator
+ struct common_speculative_params params_spec;
+ params_spec.n_draft = n_draft;
+ params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
+ params_spec.p_min = p_min;
+
+ struct common_speculative * spec = common_speculative_init(ctx_dft);
+
+ llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
+
+ const auto t_enc_end = ggml_time_us();
+
+ const auto t_dec_start = ggml_time_us();
+
+ while (true) {
+ // optionally, generate draft tokens that can be appended to the target batch
+ //
+ // this is the most important part of the speculation. the more probable tokens that are provided here
+ // the better the performance will be. in theory, this computation can be performed asynchronously and even
+ // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
+ // from a cache or lookup tables.
+ //
+ llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
+
+ //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
+
+ // always have a token to evaluate from before - id_last
+ common_batch_clear(batch_tgt);
+ common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
+
+ // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
+ {
+ // do not waste time on small drafts
+ if (draft.size() < (size_t) n_draft_min) {
+ draft.clear();
+ }
+
+ for (size_t i = 0; i < draft.size(); ++i) {
+ common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+ }
+
+ //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
+
+ llama_decode(ctx_tgt, batch_tgt);
+ }
+
+ // sample from the full target batch and return the accepted tokens based on the target sampler
+ //
+ // for each token to be accepted, the sampler would have to sample that same token
+ // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
+ // available logits from the batch and sample the next token until we run out of logits or the sampler
+ // disagrees with the draft
+ //
+ const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
+
+ //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
+
+ GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
+
+ n_past += ids.size() - 1;
+ n_drafted += draft.size(); // note: we ignore the discarded small drafts
+ n_accept += ids.size() - 1;
+ n_predict += ids.size();
+
+ // process the accepted tokens and update contexts
+ //
+ // this is the standard token post-processing that we normally do
+ // in this case, we do it for a group of accepted tokens at once
+ //
+ for (size_t i = 0; i < ids.size(); ++i) {
+ prompt_tgt.push_back(id_last);
+
+ id_last = ids[i];
+
+ if (llama_token_is_eog(model_tgt, id_last)) {
+ has_eos = true;
+ break;
+ }
+
+ const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
+
+ if (params.use_color && i + 1 < ids.size()) {
+ LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+ } else {
+ LOG("%s", token_str.c_str());
+ }
+ }
+
+ LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+ {
+ LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+ llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+ }
+
+ if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+ break;
+ }
+ }
+
+ auto t_dec_end = ggml_time_us();
+
+ const int n_input = inp.size();
+
+ LOG("\n\n");
+
+ LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+ LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
+
+ LOG_INF("\n");
+ LOG_INF("n_draft = %d\n", n_draft);
+ LOG_INF("n_predict = %d\n", n_predict);
+ LOG_INF("n_drafted = %d\n", n_drafted);
+ LOG_INF("n_accept = %d\n", n_accept);
+ LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+ LOG_INF("\n");
+ LOG_INF("draft:\n\n");
+
+ llama_perf_context_print(ctx_dft);
+
+ LOG_INF("\n");
+ LOG_INF("target:\n\n");
+ common_perf_print(ctx_tgt, smpl);
+
+ common_sampler_free(smpl);
+ common_speculative_free(spec);
+
+ llama_free(ctx_tgt);
+ llama_free_model(model_tgt);
+
+ llama_free(ctx_dft);
+ llama_free_model(model_dft);
+
+ llama_backend_free();
+
+ LOG("\n\n");
+
+ return 0;
+ }
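
The new example leans on the `common_speculative_*` helper and on `common_sampler_sample_and_accept_n`, whose behaviour is only described by the comments above. As a rough illustration of what that verification step amounts to, the hedged sketch below re-expresses it with the lower-level `common_sampler_sample`/`common_sampler_accept` calls that appear elsewhere in this diff; it is illustrative only and not the library's actual implementation.

```cpp
// Hedged sketch: a conceptual equivalent of common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft)
// for the batch layout built above (logit index 0 = id_last, logit index i + 1 = draft[i]).
// Illustrative only; the real helper lives in common/sampling.cpp of the vendored llama.cpp.
#include "common.h"
#include "sampling.h"
#include "llama.h"

#include <vector>

static std::vector<llama_token> sample_and_accept_n_sketch(
        struct common_sampler * smpl,
        struct llama_context  * ctx,
        const std::vector<llama_token> & draft) {
    std::vector<llama_token> ids;
    ids.reserve(draft.size() + 1);

    for (size_t i = 0; i <= draft.size(); ++i) {
        // sample from the logits at batch index i and update the sampler state
        const llama_token id = common_sampler_sample(smpl, ctx, (int) i);
        common_sampler_accept(smpl, id, true);
        ids.push_back(id);

        // keep consuming drafted positions only while the target sampler agrees with the draft
        if (i < draft.size() && id != draft[i]) {
            break;
        }
    }

    // at least one token (the continuation of id_last) is always produced
    return ids;
}
```

This is why the example asserts `ids.size() > 0` and rolls the KV cache back to `n_past` afterwards: any drafted positions past the first disagreement were decoded but never accepted.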

package/src/llama.cpp/examples/tokenize/CMakeLists.txt

@@ -2,4 +2,4 @@ set(TARGET llama-tokenize)
  add_executable(${TARGET} tokenize.cpp)
  install(TARGETS ${TARGET} RUNTIME)
  target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
- target_compile_features(${TARGET} PRIVATE cxx_std_11)
+ target_compile_features(${TARGET} PRIVATE cxx_std_17)

package/src/llama.cpp/examples/tokenize/tokenize.cpp

@@ -365,7 +365,7 @@ int main(int raw_argc, char ** raw_argv) {
  const bool parse_special = !no_parse_special;

  std::vector<llama_token> tokens;
- tokens = ::llama_tokenize(model, prompt, add_bos, parse_special);
+ tokens = common_tokenize(model, prompt, add_bos, parse_special);

  if (printing_ids) {
  printf("[");
@@ -380,7 +380,7 @@ int main(int raw_argc, char ** raw_argv) {
  } else {
  bool invalid_utf8 = false;
  printf("%6d -> '", tokens[i]);
- write_utf8_cstr_to_stdout(llama_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8);
+ write_utf8_cstr_to_stdout(common_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8);
  if (invalid_utf8) {
  printf("' (utf-8 decode failure)\n");
  } else {
@@ -394,7 +394,7 @@ int main(int raw_argc, char ** raw_argv) {
  }

  if (show_token_count) {
- printf("Total number of tokens: %ld\n", tokens.size());
+ printf("Total number of tokens: %zu\n", tokens.size());
  }
  // silence valgrind
  llama_free(ctx);
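
A side note on the `%ld` -> `%zu` change above: `tokens.size()` returns a `std::size_t`, and `%zu` is the conversion specifier defined for `size_t`, whereas `%ld` expects a `long` and is wrong on platforms where the two types differ in width (notably 64-bit Windows). A minimal stand-alone illustration, not from this package:

```cpp
// Minimal illustration of the printf format fix shown above.
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> tokens = {1, 2, 3};

    // %zu matches std::size_t on every platform; %ld is only correct where long
    // happens to have the same width as size_t.
    std::printf("Total number of tokens: %zu\n", tokens.size());
    return 0;
}
```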

package/src/llama.cpp/examples/tts/CMakeLists.txt (new file)

@@ -0,0 +1,5 @@
+ set(TARGET llama-tts)
+ add_executable(${TARGET} tts.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_17)