cui-llama.rn 1.4.4 → 1.4.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. package/android/src/main/CMakeLists.txt +2 -2
  2. package/android/src/main/jni.cpp +12 -10
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/chat-template.hpp +529 -529
  12. package/cpp/chat.cpp +959 -265
  13. package/cpp/chat.h +135 -0
  14. package/cpp/common.cpp +2064 -1996
  15. package/cpp/common.h +700 -744
  16. package/cpp/ggml-alloc.c +1039 -1030
  17. package/cpp/ggml-alloc.h +1 -1
  18. package/cpp/ggml-backend-impl.h +255 -255
  19. package/cpp/ggml-backend-reg.cpp +586 -582
  20. package/cpp/ggml-backend.cpp +2004 -2002
  21. package/cpp/ggml-backend.h +354 -354
  22. package/cpp/ggml-common.h +1851 -1851
  23. package/cpp/ggml-cpp.h +39 -39
  24. package/cpp/ggml-cpu-aarch64.cpp +4248 -4247
  25. package/cpp/ggml-cpu-aarch64.h +8 -8
  26. package/cpp/ggml-cpu-impl.h +531 -380
  27. package/cpp/ggml-cpu-quants.c +12527 -11517
  28. package/cpp/ggml-cpu-traits.cpp +36 -36
  29. package/cpp/ggml-cpu-traits.h +38 -38
  30. package/cpp/ggml-cpu.c +15766 -14485
  31. package/cpp/ggml-cpu.cpp +655 -633
  32. package/cpp/ggml-cpu.h +138 -135
  33. package/cpp/ggml-impl.h +567 -567
  34. package/cpp/ggml-metal-impl.h +235 -0
  35. package/cpp/ggml-metal.h +66 -66
  36. package/cpp/ggml-metal.m +5146 -5002
  37. package/cpp/ggml-opt.cpp +854 -854
  38. package/cpp/ggml-opt.h +216 -216
  39. package/cpp/ggml-quants.c +5238 -5238
  40. package/cpp/ggml-threading.h +14 -14
  41. package/cpp/ggml.c +6529 -6524
  42. package/cpp/ggml.h +2198 -2194
  43. package/cpp/gguf.cpp +1329 -1329
  44. package/cpp/gguf.h +202 -202
  45. package/cpp/json-schema-to-grammar.cpp +1024 -1025
  46. package/cpp/json-schema-to-grammar.h +21 -22
  47. package/cpp/json.hpp +24766 -24766
  48. package/cpp/llama-adapter.cpp +347 -347
  49. package/cpp/llama-adapter.h +74 -74
  50. package/cpp/llama-arch.cpp +1513 -1492
  51. package/cpp/llama-arch.h +403 -402
  52. package/cpp/llama-batch.cpp +368 -368
  53. package/cpp/llama-batch.h +88 -88
  54. package/cpp/llama-chat.cpp +588 -587
  55. package/cpp/llama-chat.h +53 -53
  56. package/cpp/llama-context.cpp +1775 -1775
  57. package/cpp/llama-context.h +128 -128
  58. package/cpp/llama-cparams.cpp +1 -1
  59. package/cpp/llama-cparams.h +37 -37
  60. package/cpp/llama-cpp.h +30 -30
  61. package/cpp/llama-grammar.cpp +1219 -1219
  62. package/cpp/llama-grammar.h +173 -164
  63. package/cpp/llama-hparams.cpp +71 -71
  64. package/cpp/llama-hparams.h +139 -139
  65. package/cpp/llama-impl.cpp +167 -167
  66. package/cpp/llama-impl.h +61 -61
  67. package/cpp/llama-kv-cache.cpp +718 -718
  68. package/cpp/llama-kv-cache.h +219 -218
  69. package/cpp/llama-mmap.cpp +600 -590
  70. package/cpp/llama-mmap.h +68 -68
  71. package/cpp/llama-model-loader.cpp +1124 -1124
  72. package/cpp/llama-model-loader.h +167 -167
  73. package/cpp/llama-model.cpp +4087 -4023
  74. package/cpp/llama-model.h +370 -370
  75. package/cpp/llama-sampling.cpp +2558 -2525
  76. package/cpp/llama-sampling.h +32 -32
  77. package/cpp/llama-vocab.cpp +3264 -3252
  78. package/cpp/llama-vocab.h +125 -125
  79. package/cpp/llama.cpp +10284 -10137
  80. package/cpp/llama.h +1354 -1340
  81. package/cpp/log.cpp +393 -423
  82. package/cpp/log.h +132 -132
  83. package/cpp/minja/chat-template.hpp +529 -0
  84. package/cpp/minja/minja.hpp +2915 -0
  85. package/cpp/minja.hpp +2915 -2883
  86. package/cpp/rn-llama.cpp +20 -37
  87. package/cpp/rn-llama.h +12 -2
  88. package/cpp/sampling.cpp +570 -532
  89. package/cpp/sgemm.cpp +2598 -2598
  90. package/cpp/sgemm.h +14 -14
  91. package/cpp/speculative.cpp +278 -277
  92. package/cpp/speculative.h +28 -28
  93. package/package.json +1 -1
  94. package/android/src/main/build-arm64/CMakeCache.txt +0 -429
  95. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  96. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +0 -101
  97. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  98. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  99. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  100. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  101. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  102. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  103. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  104. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -431
  105. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +0 -16
  106. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +0 -165
  107. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +0 -297
  108. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +0 -1
  109. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +0 -1
  110. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +0 -1
  111. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +0 -1
  112. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +0 -1
  113. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +0 -1
  114. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +0 -1
  115. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +0 -8
  116. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +0 -1
  117. package/android/src/main/build-arm64/CMakeFiles/progress.marks +0 -1
  118. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  119. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +0 -58
  120. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  121. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +0 -756
  122. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  123. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +0 -709
  124. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  125. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +0 -714
  126. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  127. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +0 -62
  128. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  129. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +0 -708
  130. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  131. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +0 -113
  132. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  133. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +0 -713
  134. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  135. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +0 -763
  136. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  137. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +0 -61
  138. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  139. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +0 -707
  140. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  141. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +0 -104
  142. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  143. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +0 -714
  144. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  145. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +0 -723
  146. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +0 -62
  147. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +0 -722
  148. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +0 -89
  149. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +0 -2
  150. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +0 -2
  151. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +0 -2
  152. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +0 -17
  153. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +0 -41
  154. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +0 -62
  155. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +0 -722
  156. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +0 -89
  157. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +0 -2
  158. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +0 -2
  159. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +0 -2
  160. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +0 -17
  161. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +0 -41
  162. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +0 -62
  163. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +0 -722
  164. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +0 -89
  165. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +0 -2
  166. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +0 -2
  167. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +0 -2
  168. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +0 -17
  169. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +0 -41
  170. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +0 -62
  171. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +0 -722
  172. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +0 -89
  173. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +0 -2
  174. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +0 -2
  175. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +0 -2
  176. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +0 -17
  177. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +0 -41
  178. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +0 -62
  179. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +0 -722
  180. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +0 -89
  181. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +0 -2
  182. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +0 -2
  183. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +0 -2
  184. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +0 -17
  185. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +0 -41
  186. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +0 -62
  187. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +0 -722
  188. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +0 -89
  189. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +0 -2
  190. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +0 -2
  191. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +0 -2
  192. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +0 -17
  193. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +0 -41
  194. package/android/src/main/build-arm64/Makefile +0 -1862
  195. package/android/src/main/build-arm64/cmake_install.cmake +0 -66
  196. package/cpp/chat.hpp +0 -55
  197. package/cpp/rn-llama.hpp +0 -913
package/cpp/sgemm.h CHANGED
@@ -1,14 +1,14 @@
- #pragma once
- #include <stdint.h>
- #include <stdbool.h>
- #ifdef __cplusplus
- extern "C" {
- #endif
-
- bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t, int64_t, int64_t,
-                      const void *, int64_t, const void *, int64_t, void *, int64_t,
-                      int, int, int);
-
- #ifdef __cplusplus
- }
- #endif
+ #pragma once
+ #include <stdint.h>
+ #include <stdbool.h>
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ bool llamafile_sgemm(const struct lm_ggml_compute_params * params, int64_t, int64_t, int64_t,
+                      const void *, int64_t, const void *, int64_t, void *, int64_t,
+                      int, int, int);
+
+ #ifdef __cplusplus
+ }
+ #endif
package/cpp/speculative.cpp CHANGED
@@ -1,277 +1,278 @@
- #include "speculative.h"
-
- #include "log.h"
- #include "common.h"
- #include "sampling.h"
-
- #include <cstring>
-
- #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
- #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
-
- struct common_speculative {
- struct llama_context * ctx;
- struct common_sampler * smpl;
-
- llama_batch batch;
- llama_tokens prompt;
- };
-
- struct common_speculative * common_speculative_init(
- struct llama_context * ctx_dft) {
- auto * result = new common_speculative {
- /* .ctx = */ ctx_dft,
- /* .smpl = */ nullptr,
- /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
- /* .prompt = */ {},
- };
-
- // TODO: optimize or pass from outside?
- #if 0
- {
- common_params_sampling params;
- params.no_perf = false;
-
- params.top_k = 40;
- params.top_p = 0.9;
-
- params.samplers = {
- COMMON_SAMPLER_TYPE_TOP_K,
- COMMON_SAMPLER_TYPE_TOP_P,
- COMMON_SAMPLER_TYPE_INFILL,
- };
-
- result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
- }
- #else
- {
- common_params_sampling params;
- params.no_perf = false;
-
- params.top_k = 10;
-
- params.samplers = {
- COMMON_SAMPLER_TYPE_TOP_K,
- };
-
- result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
- }
- #endif
-
- return result;
- }
-
- void common_speculative_free(struct common_speculative * spec) {
- if (spec == nullptr) {
- return;
- }
-
- common_sampler_free(spec->smpl);
-
- llama_batch_free(spec->batch);
-
- delete spec;
- }
-
- bool common_speculative_are_compatible(
- const struct llama_context * ctx_tgt,
- const struct llama_context * ctx_dft) {
- const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
- const struct llama_model * model_dft = llama_get_model(ctx_dft);
-
- const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
- const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
-
- const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
- LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
-
- const bool vocab_type_dft = llama_vocab_type(vocab_dft);
- LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
-
- if (vocab_type_tgt != vocab_type_dft) {
- LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
- "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
- return false;
- }
-
- if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
- llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
- llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
- llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
- LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
- LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
- LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
- return false;
- }
-
- {
- const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
- const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
-
- const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
-
- if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
- LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
- "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
- __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
- return false;
- }
-
- for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
- const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
- const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
- if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
- LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
- "token %d content differs - target '%s', draft '%s'\n", __func__, i,
- common_token_to_piece(ctx_tgt, i).c_str(),
- common_token_to_piece(ctx_dft, i).c_str());
- return false;
- }
- }
- }
-
- return true;
- }
-
- llama_tokens common_speculative_gen_draft(
- struct common_speculative * spec,
- struct common_speculative_params params,
- const llama_tokens & prompt_tgt,
- llama_token id_last) {
- auto & batch = spec->batch;
- auto & ctx = spec->ctx;
- auto & smpl = spec->smpl;
- auto & prompt = spec->prompt;
-
- int reuse_i = 0;
- int reuse_n = 0;
-
- const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
-
- const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
-
- // reuse as much as possible from the old draft context
- // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
- for (int i = 0; i < (int) prompt.size(); ++i) {
- int cur = 0;
- while (i_start + cur < (int) prompt_tgt.size() &&
- i + cur < (int) prompt.size() &&
- prompt_tgt[i_start + cur] == prompt[i + cur]) {
- cur++;
- }
-
- if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
- reuse_i = i;
- reuse_n = cur;
- }
- }
-
- LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
-
- llama_tokens result;
- result.reserve(params.n_draft);
-
- if (reuse_n == 0) {
- llama_kv_cache_clear(ctx);
-
- prompt.clear();
- } else {
- // this happens when a previous draft has been discarded (for example, due to being too small), but the
- // target model agreed with it. in this case, we simply pass back the previous results to save compute
- if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
- for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
- result.push_back(prompt[i]);
-
- if (params.n_draft <= (int) result.size()) {
- break;
- }
- }
-
- return result;
- }
-
- if (reuse_i > 0) {
- llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
- llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
-
- prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
- }
-
- if (reuse_n < (int) prompt.size()) {
- llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
-
- prompt.erase(prompt.begin() + reuse_n, prompt.end());
- }
- }
-
- // prepare a batch to evaluate any new tokens in the prompt
- common_batch_clear(batch);
-
- for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
- //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
- common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
-
- prompt.push_back(prompt_tgt[i]);
- }
-
- // we should rarely end-up here during normal decoding
- if (batch.n_tokens > 0) {
- //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
-
- llama_decode(ctx, batch);
- }
-
- const llama_pos n_past = prompt.size();
-
- LOG_DBG("%s: n_past = %d\n", __func__, n_past);
-
- common_batch_clear(batch);
- common_batch_add (batch, id_last, n_past, { 0 }, true);
-
- prompt.push_back(id_last);
-
- //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
-
- llama_decode(ctx, batch);
-
- common_sampler_reset(smpl);
-
- // sample n_draft tokens from the draft model
- for (int i = 0; i < params.n_draft; ++i) {
- common_batch_clear(batch);
-
- common_sampler_sample(smpl, ctx, 0, true);
-
- const auto * cur_p = common_sampler_get_candidates(smpl);
-
- for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
- LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
- k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
- }
-
- // add drafted token for each sequence
- const llama_token id = cur_p->data[0].id;
-
- // only collect very high-confidence draft tokens
- if (cur_p->data[0].p < params.p_min) {
- break;
- }
-
- common_sampler_accept(smpl, id, true);
-
- result.push_back(id);
-
- if (params.n_draft <= (int) result.size()) {
- break;
- }
-
- common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
-
- // evaluate the drafted tokens on the draft model
- llama_decode(ctx, batch);
-
- prompt.push_back(id);
- }
-
- return result;
- }
+ #include "speculative.h"
+
+ #include "log.h"
+ #include "common.h"
+ #include "sampling.h"
+
+ #include <cstring>
+ #include <algorithm>
+
+ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
+ #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
+ struct common_speculative {
+ struct llama_context * ctx;
+ struct common_sampler * smpl;
+
+ llama_batch batch;
+ llama_tokens prompt;
+ };
+
+ struct common_speculative * common_speculative_init(
+ struct llama_context * ctx_dft) {
+ auto * result = new common_speculative {
+ /* .ctx = */ ctx_dft,
+ /* .smpl = */ nullptr,
+ /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+ /* .prompt = */ {},
+ };
+
+ // TODO: optimize or pass from outside?
+ #if 0
+ {
+ common_params_sampling params;
+ params.no_perf = false;
+
+ params.top_k = 40;
+ params.top_p = 0.9;
+
+ params.samplers = {
+ COMMON_SAMPLER_TYPE_TOP_K,
+ COMMON_SAMPLER_TYPE_TOP_P,
+ COMMON_SAMPLER_TYPE_INFILL,
+ };
+
+ result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+ }
+ #else
+ {
+ common_params_sampling params;
+ params.no_perf = false;
+
+ params.top_k = 10;
+
+ params.samplers = {
+ COMMON_SAMPLER_TYPE_TOP_K,
+ };
+
+ result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+ }
+ #endif
+
+ return result;
+ }
+
+ void common_speculative_free(struct common_speculative * spec) {
+ if (spec == nullptr) {
+ return;
+ }
+
+ common_sampler_free(spec->smpl);
+
+ llama_batch_free(spec->batch);
+
+ delete spec;
+ }
+
+ bool common_speculative_are_compatible(
+ const struct llama_context * ctx_tgt,
+ const struct llama_context * ctx_dft) {
+ const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+ const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+ const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
+ const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
+
+ const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
+ LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+ const bool vocab_type_dft = llama_vocab_type(vocab_dft);
+ LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+ if (vocab_type_tgt != vocab_type_dft) {
+ LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+ "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+ return false;
+ }
+
+ if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+ llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
+ llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
+ llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
+ LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
+ LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
+ LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
+ return false;
+ }
+
+ {
+ const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
+ const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
+
+ const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+ if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+ LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+ "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+ __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+ return false;
+ }
+
+ for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+ const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
+ const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
+ if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+ LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
+ "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+ common_token_to_piece(ctx_tgt, i).c_str(),
+ common_token_to_piece(ctx_dft, i).c_str());
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+ llama_tokens common_speculative_gen_draft(
+ struct common_speculative * spec,
+ struct common_speculative_params params,
+ const llama_tokens & prompt_tgt,
+ llama_token id_last) {
+ auto & batch = spec->batch;
+ auto & ctx = spec->ctx;
+ auto & smpl = spec->smpl;
+ auto & prompt = spec->prompt;
+
+ int reuse_i = 0;
+ int reuse_n = 0;
+
+ const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+
+ const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
+
+ // reuse as much as possible from the old draft context
+ // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
+ for (int i = 0; i < (int) prompt.size(); ++i) {
+ int cur = 0;
+ while (i_start + cur < (int) prompt_tgt.size() &&
+ i + cur < (int) prompt.size() &&
+ prompt_tgt[i_start + cur] == prompt[i + cur]) {
+ cur++;
+ }
+
+ if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
+ reuse_i = i;
+ reuse_n = cur;
+ }
+ }
+
+ LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+
+ llama_tokens result;
+ result.reserve(params.n_draft);
+
+ if (reuse_n == 0) {
+ llama_kv_cache_clear(ctx);
+
+ prompt.clear();
+ } else {
+ // this happens when a previous draft has been discarded (for example, due to being too small), but the
+ // target model agreed with it. in this case, we simply pass back the previous results to save compute
+ if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
+ for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
+ result.push_back(prompt[i]);
+
+ if (params.n_draft <= (int) result.size()) {
+ break;
+ }
+ }
+
+ return result;
+ }
+
+ if (reuse_i > 0) {
+ llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+ llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+
+ prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+ }
+
+ if (reuse_n < (int) prompt.size()) {
+ llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+
+ prompt.erase(prompt.begin() + reuse_n, prompt.end());
+ }
+ }
+
+ // prepare a batch to evaluate any new tokens in the prompt
+ common_batch_clear(batch);
+
+ for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
+ //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+ common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+
+ prompt.push_back(prompt_tgt[i]);
+ }
+
+ // we should rarely end-up here during normal decoding
+ if (batch.n_tokens > 0) {
+ //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+
+ llama_decode(ctx, batch);
+ }
+
+ const llama_pos n_past = prompt.size();
+
+ LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+ common_batch_clear(batch);
+ common_batch_add (batch, id_last, n_past, { 0 }, true);
+
+ prompt.push_back(id_last);
+
+ //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+
+ llama_decode(ctx, batch);
+
+ common_sampler_reset(smpl);
+
+ // sample n_draft tokens from the draft model
+ for (int i = 0; i < params.n_draft; ++i) {
+ common_batch_clear(batch);
+
+ common_sampler_sample(smpl, ctx, 0, true);
+
+ const auto * cur_p = common_sampler_get_candidates(smpl);
+
+ for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+ LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+ k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+ }
+
+ // add drafted token for each sequence
+ const llama_token id = cur_p->data[0].id;
+
+ common_sampler_accept(smpl, id, true);
+
+ result.push_back(id);
+
+ if (params.n_draft <= (int) result.size()) {
+ break;
+ }
+
+ // only collect very high-confidence draft tokens
+ if (cur_p->data[0].p < params.p_min) {
+ break;
+ }
+
+ common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+
+ // evaluate the drafted tokens on the draft model
+ llama_decode(ctx, batch);
+
+ prompt.push_back(id);
+ }
+
+ return result;
+ }
package/cpp/speculative.h CHANGED
@@ -1,28 +1,28 @@
- #pragma once
-
- #include "llama.h"
- #include "common.h"
-
- struct common_speculative;
-
- struct common_speculative_params {
- int n_draft = 16; // max drafted tokens
- int n_reuse = 256;
-
- float p_min = 0.9f; // min probabiliy required to accept a token in the draft
- };
-
- struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
-
- void common_speculative_free(struct common_speculative * spec);
-
- bool common_speculative_are_compatible(
- const struct llama_context * ctx_tgt,
- const struct llama_context * ctx_dft);
-
- // sample up to n_draft tokens and add them to the batch using the draft model
- llama_tokens common_speculative_gen_draft(
- struct common_speculative * spec,
- struct common_speculative_params params,
- const llama_tokens & prompt,
- llama_token id_last);
+ #pragma once
+
+ #include "llama.h"
+ #include "common.h"
+
+ struct common_speculative;
+
+ struct common_speculative_params {
+ int n_draft = 16; // max drafted tokens
+ int n_reuse = 256;
+
+ float p_min = 0.75f; // min probability required to accept a token in the draft
+ };
+
+ struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
+
+ void common_speculative_free(struct common_speculative * spec);
+
+ bool common_speculative_are_compatible(
+ const struct llama_context * ctx_tgt,
+ const struct llama_context * ctx_dft);
+
+ // sample up to n_draft tokens and add them to the batch using the draft model
+ llama_tokens common_speculative_gen_draft(
+ struct common_speculative * spec,
+ struct common_speculative_params params,
+ const llama_tokens & prompt,
+ llama_token id_last);
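
For context, here is a minimal caller-side sketch (not part of the package) of how the speculative API declared in speculative.h above might be wired up. It assumes ctx_tgt and ctx_dft are llama_context pointers created elsewhere, that llama_tokens is the std::vector<llama_token> alias from common.h, and that error handling and target-side verification of the draft are handled by the caller.

// Sketch only: exercises the functions declared in speculative.h.
// ctx_tgt / ctx_dft are assumed to be already-initialized llama_context pointers.
#include "speculative.h"

llama_tokens draft_once(llama_context * ctx_tgt,
                        llama_context * ctx_dft,
                        const llama_tokens & prompt_tgt,
                        llama_token id_last) {
    // refuse to speculate if the two vocabularies are not close enough
    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
        return {};
    }

    common_speculative * spec = common_speculative_init(ctx_dft);

    common_speculative_params params;
    params.n_draft = 16;    // max drafted tokens (header default)
    params.p_min   = 0.75f; // acceptance threshold, matching the new default in 1.4.6

    // propose up to n_draft continuations of prompt_tgt followed by id_last;
    // the target model must still verify these tokens
    llama_tokens draft = common_speculative_gen_draft(spec, params, prompt_tgt, id_last);

    common_speculative_free(spec);
    return draft;
}

In practice the helper would be kept alive across decode steps, since common_speculative_gen_draft reuses the draft KV cache and previous prompt between calls; creating and freeing it per call, as above, is only for brevity.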