@fugood/llama.node 0.3.3 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +1 -1
  21. package/src/LlamaContext.cpp +81 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
package/src/llama.cpp/common/sampling.cpp

@@ -99,7 +99,7 @@ struct ring_buffer {
 };
 
 struct common_sampler {
-    common_sampler_params params;
+    common_params_sampling params;
 
     struct llama_sampler * grmr;
     struct llama_sampler * chain;
@@ -125,7 +125,7 @@ struct common_sampler {
     }
 };
 
-std::string common_sampler_params::print() const {
+std::string common_params_sampling::print() const {
     char result[1024];
 
     snprintf(result, sizeof(result),
@@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
     return std::string(result);
 }
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
 
     lparams.no_perf = params.no_perf;
@@ -161,32 +161,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.logit_bias.size(),
                 params.logit_bias.data()));
 
-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_penalties(
-                llama_n_vocab  (model),
-                llama_token_eos(model),
-                llama_token_nl (model),
-                params.penalty_last_n,
-                params.penalty_repeat,
-                params.penalty_freq,
-                params.penalty_present,
-                params.penalize_nl,
-                params.ignore_eos));
-
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
-                case COMMON_SAMPLER_TYPE_DRY:
+                case COMMON_SAMPLER_TYPE_DRY:
                     {
-                        std::vector<const char*> c_breakers;
+                        std::vector<const char *> c_breakers;
                         c_breakers.reserve(params.dry_sequence_breakers.size());
-                        for (const auto& str : params.dry_sequence_breakers) {
+                        for (const auto & str : params.dry_sequence_breakers) {
                             c_breakers.push_back(str.c_str());
                         }
 
                         llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
-                    break;
+                    break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                     break;
@@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_INFILL:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
                     break;
+                case COMMON_SAMPLER_TYPE_PENALTIES:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
@@ -320,6 +311,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return cur_p.data[cur_p.selected].id;
 }
 
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        common_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+
+        if (draft[i] != id) {
+            break;
+        }
+    }
+
+    if (i == draft.size()) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        common_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
+
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
+    }
+
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+}
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
     return llama_sampler_get_seed(gsmpl->chain);
 }
@@ -376,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
         case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
        default : return '?';
    }
}
@@ -390,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
         case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
        default : return "";
    }
}
@@ -404,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
+       { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
    };

    // since samplers names are written multiple ways
@@ -450,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
+       { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
    };

    std::vector<common_sampler_type> samplers;
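
Note: the net effect of the sampling.cpp changes above is that the repetition/frequency/presence penalties are no longer hard-wired into every sampler chain; they run only when COMMON_SAMPLER_TYPE_PENALTIES appears in the configured sampler order. A minimal sketch of how a caller could opt in (not part of the package; the model pointer and the parameter values are assumed, and the field names follow the diff above):

// sketch: build a common_sampler whose chain includes the new "penalties" stage
#include "common.h"
#include "sampling.h"

static struct common_sampler * make_sampler(const struct llama_model * model) {
    common_params_sampling params;

    params.penalty_last_n = 64;   // fields referenced in the diff above
    params.penalty_repeat = 1.1f;

    // penalties now run only if (and where) they appear in this order
    params.samplers = {
        COMMON_SAMPLER_TYPE_PENALTIES,
        COMMON_SAMPLER_TYPE_TOP_K,
        COMMON_SAMPLER_TYPE_TEMPERATURE,
    };

    return common_sampler_init(model, params);
}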
package/src/llama.cpp/common/sampling.h

@@ -36,7 +36,7 @@ struct common_sampler;
 
 // llama_sampler API overloads
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
 
 void common_sampler_free(struct common_sampler * gsmpl);
 
@@ -60,6 +60,27 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 //
 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
+// generalized version of common_sampler_sample
+//
+// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
+// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
+//
+//    common_sampler_sample_n(gsmpl, ctx, { idx }, {});
+//
+// is equivalent to
+//
+//    common_sampler_sample(gsmpl, ctx, idx);
+//    common_sampler_accept(gsmpl, token, true);
+//
+// requires: idxs.size() == draft.size() + 1
+//
+// returns at least 1 token, up to idxs.size()
+//
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+
+// assume idxs == [ 0, 1, 2, ..., draft.size() ]
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
 // helpers
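
The header comments above spell out the contract for the new overloads. As a rough usage sketch (not from the package; it assumes the target-model logits for the draft positions have already been computed on ctx):

// sketch: accept the longest prefix of `draft` that the target sampler agrees with
static llama_token accept_draft(struct common_sampler * gsmpl,
                                struct llama_context  * ctx,
                                const llama_tokens    & draft) {
    // one token is sampled per draft position plus one extra; sampling stops at
    // the first disagreement, so `accepted` holds 1 .. draft.size() + 1 tokens
    std::vector<llama_token> accepted = common_sampler_sample_and_accept_n(gsmpl, ctx, draft);

    // the last accepted token is the next token to extend the target context with
    return accepted.back();
}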
package/src/llama.cpp/common/speculative.cpp (new file)

@@ -0,0 +1,274 @@
+#include "speculative.h"
+
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+
+#include <cstring>
+
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
+struct common_speculative {
+    struct llama_context * ctx;
+    struct common_sampler * smpl;
+
+    llama_batch batch;
+    llama_tokens prompt;
+};
+
+struct common_speculative * common_speculative_init(
+        struct llama_context * ctx_dft) {
+    auto * result = new common_speculative {
+        /* .ctx    = */ ctx_dft,
+        /* .smpl   = */ nullptr,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt = */ {},
+    };
+
+    // TODO: optimize or pass from outside?
+#if 0
+    {
+        common_params_sampling params;
+        params.no_perf = false;
+
+        params.top_k = 40;
+        params.top_p = 0.9;
+
+        params.samplers = {
+            COMMON_SAMPLER_TYPE_TOP_K,
+            COMMON_SAMPLER_TYPE_TOP_P,
+            COMMON_SAMPLER_TYPE_INFILL,
+        };
+
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+    }
+#else
+    {
+        common_params_sampling params;
+        params.no_perf = false;
+
+        params.top_k = 10;
+
+        params.samplers = {
+            COMMON_SAMPLER_TYPE_TOP_K,
+        };
+
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+    }
+#endif
+
+    return result;
+}
+
+void common_speculative_free(struct common_speculative * spec) {
+    if (spec == nullptr) {
+        return;
+    }
+
+    common_sampler_free(spec->smpl);
+
+    llama_batch_free(spec->batch);
+
+    delete spec;
+}
+
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
+    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+    const bool vocab_type_dft = llama_vocab_type(model_dft);
+    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        return false;
+    }
+
+    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_tgt), llama_add_bos_token(model_tgt), llama_token_eos(model_tgt), llama_add_eos_token(model_tgt));
+        LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_dft), llama_add_bos_token(model_dft), llama_token_eos(model_dft), llama_add_eos_token(model_dft));
+        return false;
+    }
+
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+
+        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return false;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                        "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                        common_token_to_piece(ctx_tgt, i).c_str(),
+                        common_token_to_piece(ctx_dft, i).c_str());
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+llama_tokens common_speculative_gen_draft(
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt_tgt,
+        llama_token id_last) {
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;
+
+    int reuse_i = 0;
+    int reuse_n = 0;
+
+    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+
+    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
+
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
+    for (int i = 0; i < (int) prompt.size(); ++i) {
+        int cur = 0;
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i       + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+            cur++;
+        }
+
+        if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
+            reuse_i = i;
+            reuse_n = cur;
+        }
+    }
+
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+
+    llama_tokens result;
+    result.reserve(params.n_draft);
+
+    if (reuse_n == 0) {
+        llama_kv_cache_clear(ctx);
+
+        prompt.clear();
+    } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
+        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
+                result.push_back(prompt[i]);
+
+                if (params.n_draft <= (int) result.size()) {
+                    break;
+                }
+            }
+
+            return result;
+        }
+
+        if (reuse_i > 0) {
+            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
+    }
+
+    // prepare a batch to evaluate any new tokens in the prompt
+    common_batch_clear(batch);
+
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+
+        prompt.push_back(prompt_tgt[i]);
+    }
+
+    // we should rarely end-up here during normal decoding
+    if (batch.n_tokens > 0) {
+        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+
+        llama_decode(ctx, batch);
+    }
+
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+    common_batch_clear(batch);
+    common_batch_add (batch, id_last, n_past, { 0 }, true);
+
+    prompt.push_back(id_last);
+
+    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+
+    llama_decode(ctx, batch);
+
+    common_sampler_reset(smpl);
+
+    // sample n_draft tokens from the draft model
+    for (int i = 0; i < params.n_draft; ++i) {
+        common_batch_clear(batch);
+
+        common_sampler_sample(smpl, ctx, 0, true);
+
+        const auto * cur_p = common_sampler_get_candidates(smpl);
+
+        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        }
+
+        // add drafted token for each sequence
+        const llama_token id = cur_p->data[0].id;
+
+        // only collect very high-confidence draft tokens
+        if (cur_p->data[0].p < params.p_min) {
+            break;
+        }
+
+        common_sampler_accept(smpl, id, true);
+
+        result.push_back(id);
+
+        if (params.n_draft <= (int) result.size()) {
+            break;
+        }
+
+        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+
+        // evaluate the drafted tokens on the draft model
+        llama_decode(ctx, batch);
+
+        prompt.push_back(id);
+    }
+
+    return result;
+}
package/src/llama.cpp/common/speculative.h (new file)

@@ -0,0 +1,28 @@
+#pragma once
+
+#include "llama.h"
+#include "common.h"
+
+struct common_speculative;
+
+struct common_speculative_params {
+    int n_draft = 16;  // max drafted tokens
+    int n_reuse = 256;
+
+    float p_min = 0.9f; // min probabiliy required to accept a token in the draft
+};
+
+struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
+
+void common_speculative_free(struct common_speculative * spec);
+
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft);
+
+// sample up to n_draft tokens and add them to the batch using the draft model
+llama_tokens common_speculative_gen_draft(
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt,
+        llama_token id_last);
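
Taken together with the new sampling overloads above, speculative.h defines the high-level flow used by the new examples/speculative-simple target in this release. A condensed, hedged sketch of one speculative decode step (ctx_tgt, ctx_dft, gsmpl, prompt_tgt and id_last are assumed to be set up by the caller; target-side batch handling is omitted):

// sketch of one speculative step, loosely following the API declared above
struct common_speculative * spec = common_speculative_init(ctx_dft);

struct common_speculative_params params_spec;
params_spec.n_draft = 16;
params_spec.p_min   = 0.9f;

if (common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
    // ask the draft model for up to n_draft continuation tokens after id_last
    llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

    // ... decode id_last plus the draft on the target model (omitted) ...

    // keep the longest prefix of the draft that the target sampler agrees with
    std::vector<llama_token> accepted = common_sampler_sample_and_accept_n(gsmpl, ctx_tgt, draft);

    // accepted.back() becomes the new id_last for the next step
}

common_speculative_free(spec);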