@fugood/llama.node 0.3.3 → 0.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +29 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +17 -1
  21. package/src/LlamaContext.cpp +86 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp
@@ -0,0 +1,265 @@
+ #include "arg.h"
+ #include "common.h"
+ #include "sampling.h"
+ #include "speculative.h"
+ #include "log.h"
+ #include "llama.h"
+
+ #include <cstdio>
+ #include <cstring>
+ #include <string>
+ #include <vector>
+
+ int main(int argc, char ** argv) {
+     common_params params;
+
+     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
+         return 1;
+     }
+
+     if (params.n_predict < -1) {
+         LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+         return 1;
+     }
+
+     common_init();
+
+     if (params.speculative.model.empty()) {
+         LOG_ERR("%s: --model-draft is required\n", __func__);
+         return 1;
+     }
+
+     // init llama.cpp
+     llama_backend_init();
+     llama_numa_init(params.numa);
+
+     llama_model * model_tgt = NULL;
+     llama_model * model_dft = NULL;
+
+     llama_context * ctx_tgt = NULL;
+     llama_context * ctx_dft = NULL;
+
+     // load the target model
+     common_init_result llama_init_tgt = common_init_from_params(params);
+
+     model_tgt = llama_init_tgt.model;
+     ctx_tgt = llama_init_tgt.context;
+
+     // load the draft model
+     params.devices = params.speculative.devices;
+     params.model = params.speculative.model;
+     params.n_ctx = params.speculative.n_ctx;
+     params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
+     params.n_gpu_layers = params.speculative.n_gpu_layers;
+
+     if (params.speculative.cpuparams.n_threads > 0) {
+         params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
+     }
+
+     params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+     common_init_result llama_init_dft = common_init_from_params(params);
+
+     model_dft = llama_init_dft.model;
+     ctx_dft = llama_init_dft.context;
+
+     if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
+         return 1;
+     }
+
+     // Tokenize the prompt
+     std::vector<llama_token> inp;
+     inp = common_tokenize(ctx_tgt, params.prompt, true, true);
+
+     if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) {
+         LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+
+         return 1;
+     }
+
+     if (llama_n_batch(ctx_tgt) < (uint32_t) inp.size()) {
+         LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
+
+         return 1;
+     }
+
+     LOG("\n\n");
+
+     for (auto id : inp) {
+         LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
+     }
+
+     // how many tokens to draft each time
+     int n_draft = params.speculative.n_max;
+     int n_draft_min = params.speculative.n_min;
+
+     float p_min = params.speculative.p_min;
+
+     int n_predict = 0;
+     int n_drafted = 0;
+     int n_accept = 0;
+
+     // used to determine end of generation
+     bool has_eos = false;
+
+     // ================================================
+     // everything until here is standard initialization
+     // the relevant stuff for speculative decoding starts here
+
+     const auto t_enc_start = ggml_time_us();
+
+     // target model sampling context
+     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
+
+     // eval the prompt
+     llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
+
+     // note: keep the last token separate!
+     llama_token id_last = inp.back();
+
+     // all tokens currently in the target context
+     llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+     prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
+
+     int n_past = inp.size() - 1;
+
+     // init the speculator
+     struct common_speculative_params params_spec;
+     params_spec.n_draft = n_draft;
+     params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
+     params_spec.p_min = p_min;
+
+     struct common_speculative * spec = common_speculative_init(ctx_dft);
+
+     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
+
+     const auto t_enc_end = ggml_time_us();
+
+     const auto t_dec_start = ggml_time_us();
+
+     while (true) {
+         // optionally, generate draft tokens that can be appended to the target batch
+         //
+         // this is the most important part of the speculation. the more probable tokens that are provided here
+         // the better the performance will be. in theory, this computation can be performed asynchronously and even
+         // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
+         // from a cache or lookup tables.
+         //
+         llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
+
+         //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
+
+         // always have a token to evaluate from before - id_last
+         common_batch_clear(batch_tgt);
+         common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
+
+         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
+         {
+             // do not waste time on small drafts
+             if (draft.size() < (size_t) n_draft_min) {
+                 draft.clear();
+             }
+
+             for (size_t i = 0; i < draft.size(); ++i) {
+                 common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+             }
+
+             //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
+
+             llama_decode(ctx_tgt, batch_tgt);
+         }
+
+         // sample from the full target batch and return the accepted tokens based on the target sampler
+         //
+         // for each token to be accepted, the sampler would have to sample that same token
+         // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
+         // available logits from the batch and sample the next token until we run out of logits or the sampler
+         // disagrees with the draft
+         //
+         const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
+
+         //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
+
+         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
+
+         n_past += ids.size() - 1;
+         n_drafted += draft.size(); // note: we ignore the discarded small drafts
+         n_accept += ids.size() - 1;
+         n_predict += ids.size();
+
+         // process the accepted tokens and update contexts
+         //
+         // this is the standard token post-processing that we normally do
+         // in this case, we do it for a group of accepted tokens at once
+         //
+         for (size_t i = 0; i < ids.size(); ++i) {
+             prompt_tgt.push_back(id_last);
+
+             id_last = ids[i];
+
+             if (llama_token_is_eog(model_tgt, id_last)) {
+                 has_eos = true;
+                 break;
+             }
+
+             const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
+
+             if (params.use_color && i + 1 < ids.size()) {
+                 LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+             } else {
+                 LOG("%s", token_str.c_str());
+             }
+         }
+
+         LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+         {
+             LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+             llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+         }
+
+         if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+             break;
+         }
+     }
+
+     auto t_dec_end = ggml_time_us();
+
+     const int n_input = inp.size();
+
+     LOG("\n\n");
+
+     LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+     LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
+
+     LOG_INF("\n");
+     LOG_INF("n_draft = %d\n", n_draft);
+     LOG_INF("n_predict = %d\n", n_predict);
+     LOG_INF("n_drafted = %d\n", n_drafted);
+     LOG_INF("n_accept = %d\n", n_accept);
+     LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+     LOG_INF("\n");
+     LOG_INF("draft:\n\n");
+
+     llama_perf_context_print(ctx_dft);
+
+     LOG_INF("\n");
+     LOG_INF("target:\n\n");
+     common_perf_print(ctx_tgt, smpl);
+
+     common_sampler_free(smpl);
+     common_speculative_free(spec);
+
+     llama_free(ctx_tgt);
+     llama_free_model(model_tgt);
+
+     llama_free(ctx_dft);
+     llama_free_model(model_dft);
+
+     llama_backend_free();
+
+     LOG("\n\n");
+
+     return 0;
+ }
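For orientation, here is a condensed sketch of the draft/verify loop that the new example is built around, using the common_speculative helpers this release adds in common/speculative.h and the common_sampler_sample_and_accept_n extension in common/sampling.h. It is not part of the diff; it restates the core of the listing above, assumes the target context, sampler, speculator and batch have already been created exactly as shown there, and omits logging, end-of-generation handling and error checks.

#include "common.h"
#include "sampling.h"
#include "speculative.h"
#include "llama.h"

// Sketch only: one iteration of the speculative loop, condensed from the
// example above. Draft with the small model, verify the whole draft with a
// single target decode, keep the agreed prefix, trim the KV cache.
static void speculative_step(llama_context * ctx_tgt,
                             common_sampler * smpl,
                             common_speculative * spec,
                             common_speculative_params & params_spec,
                             llama_batch & batch_tgt,
                             llama_tokens & prompt_tgt,
                             llama_token & id_last,
                             int & n_past) {
    // draft a continuation of the current prompt with the small model
    llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

    // evaluate [id_last, draft...] with the target model in a single batch
    common_batch_clear(batch_tgt);
    common_batch_add(batch_tgt, id_last, n_past++, { 0 }, true);
    for (size_t i = 0; i < draft.size(); ++i) {
        common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
    }
    llama_decode(ctx_tgt, batch_tgt);

    // keep the longest prefix of the draft that the target sampler agrees with
    // (always at least one token: the one sampled after id_last)
    const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
    n_past += (int) ids.size() - 1;

    // commit the accepted tokens and drop the rejected tail from the KV cache
    for (const llama_token id : ids) {
        prompt_tgt.push_back(id_last);
        id_last = id;
    }
    llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
}

The caller repeats this step, checking llama_token_is_eog and the n_predict budget after each iteration, as the full example does.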
package/src/llama.cpp/examples/tokenize/CMakeLists.txt
@@ -2,4 +2,4 @@ set(TARGET llama-tokenize)
  add_executable(${TARGET} tokenize.cpp)
  install(TARGETS ${TARGET} RUNTIME)
  target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
- target_compile_features(${TARGET} PRIVATE cxx_std_11)
+ target_compile_features(${TARGET} PRIVATE cxx_std_17)
package/src/llama.cpp/examples/tokenize/tokenize.cpp
@@ -394,7 +394,7 @@ int main(int raw_argc, char ** raw_argv) {
      }

      if (show_token_count) {
-         printf("Total number of tokens: %ld\n", tokens.size());
+         printf("Total number of tokens: %zu\n", tokens.size());
      }
      // silence valgrind
      llama_free(ctx);
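The one-line tokenize.cpp change fixes a format-specifier mismatch: std::vector::size() returns size_t, and %ld (long) does not match it on every platform (long is 32-bit on 64-bit Windows, for example), while %zu is the standard specifier for size_t. A minimal illustration, not part of the diff:

#include <cstdio>
#include <vector>

int main() {
    std::vector<int> tokens(42);
    // %zu matches size_t on all platforms; %ld only happens to work where
    // long and size_t have the same width
    std::printf("Total number of tokens: %zu\n", tokens.size());
    return 0;
}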
package/src/llama.cpp/examples/tts/CMakeLists.txt
@@ -0,0 +1,5 @@
+ set(TARGET llama-tts)
+ add_executable(${TARGET} tts.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_17)