@fugood/llama.node 0.0.1-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204)
  1. package/CMakeLists.txt +85 -0
  2. package/README.md +56 -0
  3. package/bin/darwin/arm64/llama-node.node +0 -0
  4. package/bin/darwin/x64/llama-node.node +0 -0
  5. package/bin/linux/arm64/llama-node.node +0 -0
  6. package/bin/linux/x64/llama-node.node +0 -0
  7. package/bin/win32/arm64/llama-node.node +0 -0
  8. package/bin/win32/arm64/node.lib +0 -0
  9. package/bin/win32/x64/llama-node.node +0 -0
  10. package/bin/win32/x64/node.lib +0 -0
  11. package/lib/binding.js +13 -0
  12. package/lib/binding.ts +57 -0
  13. package/lib/index.js +24 -0
  14. package/lib/index.ts +13 -0
  15. package/package.json +65 -0
  16. package/src/addons.cpp +506 -0
  17. package/src/llama.cpp/CMakeLists.txt +1320 -0
  18. package/src/llama.cpp/build.zig +172 -0
  19. package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
  20. package/src/llama.cpp/common/CMakeLists.txt +87 -0
  21. package/src/llama.cpp/common/base64.hpp +392 -0
  22. package/src/llama.cpp/common/common.cpp +2949 -0
  23. package/src/llama.cpp/common/common.h +324 -0
  24. package/src/llama.cpp/common/console.cpp +501 -0
  25. package/src/llama.cpp/common/console.h +19 -0
  26. package/src/llama.cpp/common/grammar-parser.cpp +440 -0
  27. package/src/llama.cpp/common/grammar-parser.h +29 -0
  28. package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
  29. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
  30. package/src/llama.cpp/common/json.hpp +24766 -0
  31. package/src/llama.cpp/common/log.h +724 -0
  32. package/src/llama.cpp/common/ngram-cache.cpp +282 -0
  33. package/src/llama.cpp/common/ngram-cache.h +94 -0
  34. package/src/llama.cpp/common/sampling.cpp +353 -0
  35. package/src/llama.cpp/common/sampling.h +147 -0
  36. package/src/llama.cpp/common/stb_image.h +8396 -0
  37. package/src/llama.cpp/common/train.cpp +1513 -0
  38. package/src/llama.cpp/common/train.h +233 -0
  39. package/src/llama.cpp/examples/CMakeLists.txt +52 -0
  40. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
  41. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
  42. package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
  43. package/src/llama.cpp/examples/batched/batched.cpp +262 -0
  44. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
  45. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
  46. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
  47. package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
  48. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
  49. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
  50. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
  51. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
  52. package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
  54. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
  55. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
  56. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
  58. package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
  59. package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
  60. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
  61. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
  62. package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
  64. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
  65. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
  66. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
  67. package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
  68. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
  69. package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
  70. package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
  71. package/src/llama.cpp/examples/infill/infill.cpp +767 -0
  72. package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
  73. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
  74. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
  75. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
  76. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
  77. package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
  78. package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
  79. package/src/llama.cpp/examples/llava/clip.h +85 -0
  80. package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
  81. package/src/llama.cpp/examples/llava/llava.cpp +426 -0
  82. package/src/llama.cpp/examples/llava/llava.h +50 -0
  83. package/src/llama.cpp/examples/llava/requirements.txt +3 -0
  84. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
  85. package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
  86. package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
  87. package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
  88. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
  89. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
  90. package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
  91. package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
  92. package/src/llama.cpp/examples/main/main.cpp +957 -0
  93. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
  94. package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
  95. package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
  96. package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
  97. package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
  98. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
  99. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
  100. package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
  101. package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
  102. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
  103. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
  104. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
  106. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
  107. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
  108. package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
  109. package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
  110. package/src/llama.cpp/examples/server/httplib.h +9465 -0
  111. package/src/llama.cpp/examples/server/server.cpp +3826 -0
  112. package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
  113. package/src/llama.cpp/examples/server/utils.hpp +653 -0
  114. package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
  115. package/src/llama.cpp/examples/simple/simple.cpp +183 -0
  116. package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
  117. package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
  118. package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
  119. package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
  120. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
  121. package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
  122. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
  123. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
  124. package/src/llama.cpp/ggml-alloc.c +985 -0
  125. package/src/llama.cpp/ggml-alloc.h +76 -0
  126. package/src/llama.cpp/ggml-backend-impl.h +141 -0
  127. package/src/llama.cpp/ggml-backend.c +2099 -0
  128. package/src/llama.cpp/ggml-backend.h +233 -0
  129. package/src/llama.cpp/ggml-common.h +1853 -0
  130. package/src/llama.cpp/ggml-cuda.h +43 -0
  131. package/src/llama.cpp/ggml-impl.h +265 -0
  132. package/src/llama.cpp/ggml-kompute.cpp +2006 -0
  133. package/src/llama.cpp/ggml-kompute.h +46 -0
  134. package/src/llama.cpp/ggml-metal.h +66 -0
  135. package/src/llama.cpp/ggml-mpi.c +216 -0
  136. package/src/llama.cpp/ggml-mpi.h +39 -0
  137. package/src/llama.cpp/ggml-opencl.cpp +2301 -0
  138. package/src/llama.cpp/ggml-opencl.h +36 -0
  139. package/src/llama.cpp/ggml-quants.c +12678 -0
  140. package/src/llama.cpp/ggml-quants.h +133 -0
  141. package/src/llama.cpp/ggml-sycl.cpp +17882 -0
  142. package/src/llama.cpp/ggml-sycl.h +49 -0
  143. package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
  144. package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
  145. package/src/llama.cpp/ggml-vulkan.h +29 -0
  146. package/src/llama.cpp/ggml.c +21819 -0
  147. package/src/llama.cpp/ggml.h +2403 -0
  148. package/src/llama.cpp/llama.cpp +17468 -0
  149. package/src/llama.cpp/llama.h +1117 -0
  150. package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
  151. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
  152. package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
  153. package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
  154. package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
  155. package/src/llama.cpp/prompts/alpaca.txt +1 -0
  156. package/src/llama.cpp/prompts/assistant.txt +31 -0
  157. package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
  158. package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
  159. package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
  160. package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
  161. package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
  162. package/src/llama.cpp/prompts/chat.txt +28 -0
  163. package/src/llama.cpp/prompts/dan-modified.txt +1 -0
  164. package/src/llama.cpp/prompts/dan.txt +1 -0
  165. package/src/llama.cpp/prompts/mnemonics.txt +93 -0
  166. package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
  167. package/src/llama.cpp/prompts/reason-act.txt +18 -0
  168. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
  169. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
  170. package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
  171. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
  172. package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
  173. package/src/llama.cpp/requirements.txt +12 -0
  174. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
  175. package/src/llama.cpp/scripts/xxd.cmake +16 -0
  176. package/src/llama.cpp/sgemm.cpp +999 -0
  177. package/src/llama.cpp/sgemm.h +12 -0
  178. package/src/llama.cpp/tests/CMakeLists.txt +78 -0
  179. package/src/llama.cpp/tests/get-model.cpp +21 -0
  180. package/src/llama.cpp/tests/get-model.h +2 -0
  181. package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
  182. package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
  183. package/src/llama.cpp/tests/test-c.c +7 -0
  184. package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
  185. package/src/llama.cpp/tests/test-double-float.cpp +57 -0
  186. package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
  187. package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
  188. package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
  189. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
  190. package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
  191. package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
  192. package/src/llama.cpp/tests/test-opt.cpp +181 -0
  193. package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
  194. package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
  195. package/src/llama.cpp/tests/test-rope.cpp +221 -0
  196. package/src/llama.cpp/tests/test-sampling.cpp +301 -0
  197. package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
  198. package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
  199. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
  200. package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
  201. package/src/llama.cpp/unicode-data.cpp +1651 -0
  202. package/src/llama.cpp/unicode-data.h +16 -0
  203. package/src/llama.cpp/unicode.cpp +277 -0
  204. package/src/llama.cpp/unicode.h +28 -0
@@ -0,0 +1,614 @@
1
#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <random>
#include <set>
#include <string>
#include <tuple>
#include <vector>
+
10
+ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
11
+ #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
12
+
13
+ struct seq_draft {
14
+ bool active = false;
15
+ bool drafting = false;
16
+ bool skip = false;
17
+
18
+ int i_batch_dft = 0;
19
+ std::vector<int> i_batch_tgt;
20
+
21
+ std::vector<llama_token> tokens;
22
+ std::vector<std::vector<llama_token_data>> dists;
23
+
24
+ struct llama_sampling_context * ctx_sampling;
25
+ };
26
+
27
+ int main(int argc, char ** argv) {
28
+ gpt_params params;
29
+
30
+ if (gpt_params_parse(argc, argv, params) == false) {
31
+ return 1;
32
+ }
33
+
34
+ if (params.model_draft.empty()) {
35
+ fprintf(stderr, "%s: error: --model-draft is required\n", __func__);
36
+ return 1;
37
+ }
38
+
39
+ // max number of parallel drafting sequences (i.e. tree branches)
40
+ const int n_seq_dft = params.n_parallel;
41
+
42
+ // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
43
+ const float p_split = params.p_split;
44
+
45
+ if (params.seed == LLAMA_DEFAULT_SEED) {
46
+ params.seed = time(NULL);
47
+ }
48
+ std::default_random_engine rng(params.seed);
49
+ std::uniform_real_distribution<> u_dist;
50
+
51
+ #ifndef LOG_DISABLE_LOGS
52
+ log_set_target(log_filename_generator("speculative", "log"));
53
+ LOG_TEE("Log start\n");
54
+ log_dump_cmdline(argc, argv);
55
+ #endif // LOG_DISABLE_LOGS
56
+
57
+ // init llama.cpp
58
+ llama_backend_init();
59
+ llama_numa_init(params.numa);
60
+
61
+ llama_model * model_tgt = NULL;
62
+ llama_model * model_dft = NULL;
63
+
64
+ llama_context * ctx_tgt = NULL;
65
+ llama_context * ctx_dft = NULL;
66
+
67
+ // load the target model
68
+ std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
69
+
70
+ // load the draft model
71
+ params.model = params.model_draft;
72
+ params.n_gpu_layers = params.n_gpu_layers_draft;
73
+ if (params.n_threads_draft > 0) {
74
+ params.n_threads = params.n_threads_draft;
75
+ }
76
+ params.n_threads_batch = params.n_threads_batch_draft;
77
+ std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
78
+
79
+ const bool vocab_type_tgt = llama_vocab_type(model_tgt);
80
+ LOG("vocab_type tgt: %d\n", vocab_type_tgt);
81
+
82
+ const bool vocab_type_dft = llama_vocab_type(model_dft);
83
+ LOG("vocab_type dft: %d\n", vocab_type_dft);
84
+
85
+ if (vocab_type_tgt != vocab_type_dft) {
86
+ fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__);
87
+ fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
88
+ return 1;
89
+ }
90
+
91
+ if (
92
+ llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
93
+ llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
94
+ llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
95
+ llama_token_eos(model_tgt) != llama_token_eos(model_dft)
96
+ ) {
97
+ fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
98
+ return 1;
99
+ }
100
+
101
+ {
102
+ const int n_vocab_tgt = llama_n_vocab(model_tgt);
103
+ const int n_vocab_dft = llama_n_vocab(model_dft);
104
+ const int vocab_diff = n_vocab_tgt > n_vocab_dft
105
+ ? n_vocab_tgt - n_vocab_dft
106
+ : n_vocab_dft - n_vocab_tgt;
107
+
108
+ if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
109
+ fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
110
+ fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
111
+ n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
112
+ return 1;
113
+ }
114
+
115
+ for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
116
+ const char * token_text_tgt = llama_token_get_text(model_tgt, i);
117
+ const char * token_text_dft = llama_token_get_text(model_dft, i);
118
+ if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
119
+ fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
120
+ fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
121
+ llama_token_to_piece(ctx_tgt, i).c_str(),
122
+ llama_token_to_piece(ctx_dft, i).c_str());
123
+ return 1;
124
+ }
125
+ }
126
+ }
127
+
128
+
129
+ // Tokenize the prompt
130
+ std::vector<llama_token> inp;
131
+ inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true);
132
+
133
+ const int max_context_size = llama_n_ctx(ctx_tgt);
134
+ const int max_tokens_list_size = max_context_size - 4;
135
+
136
+ if ((int) inp.size() > max_tokens_list_size) {
137
+ fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
138
+ return 1;
139
+ }
140
+
141
+ fprintf(stderr, "\n\n");
142
+
143
+ for (auto id : inp) {
144
+ fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str());
145
+ }
146
+
147
+ fflush(stderr);
148
+
149
+ const int n_input = inp.size();
150
+
151
+ const auto t_enc_start = ggml_time_us();
152
+
153
+ // eval the prompt with both models
154
+ llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
155
+ llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
156
+ llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));
157
+
158
+ const auto t_enc_end = ggml_time_us();
159
+
160
+ // the 2 models should have the same vocab
161
+ //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
162
+
163
+ // how many tokens to draft each time
164
+ int n_draft = params.n_draft;
165
+
166
+ int n_predict = 0;
167
+ int n_drafted = 0;
168
+ int n_accept = 0;
169
+
170
+ int n_past_tgt = inp.size();
171
+ int n_past_dft = inp.size();
172
+
173
+ // used to determine end of generation
174
+ bool has_eos = false;
175
+
176
+ // target model sampling context
177
+ struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
178
+
179
+ // draft sequence data
180
+ std::vector<seq_draft> drafts(n_seq_dft);
181
+
182
+ params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
183
+ if (params.sparams.temp == 0) {
184
+ params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
185
+ }
186
+
187
+ for (int s = 0; s < n_seq_dft; ++s) {
188
+ drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
189
+ }
190
+
191
+ llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
192
+ llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
193
+
194
+ const auto t_dec_start = ggml_time_us();
195
+
196
+ // sample from the last token of the prompt
197
+ drafts[0].i_batch_tgt.resize(1);
198
+ drafts[0].i_batch_tgt[0] = 0;
199
+
200
+ while (true) {
201
+ std::set<int> active_seqs = {};
202
+
203
+ // print current draft sequences
204
+ for (int s = 0; s < n_seq_dft; ++s) {
205
+ if (!drafts[s].active) {
206
+ continue;
207
+ }
208
+
209
+ active_seqs.insert(s);
210
+ const auto & tokens = drafts[s].tokens;
211
+
212
+ LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
213
+ }
214
+
215
+ int i_dft = 0;
216
+ int s_keep = 0;
217
+
218
+ llama_token token_id;
219
+ std::string token_str;
220
+
221
+ // loop until we fail to accept a drafted token or we run out of drafted tokens
222
+ while (true) {
223
+
224
+ // check if the target token matches any of the drafts
225
+ // for stochastic sampling, attempt to match the token with the drafted tokens
226
+ {
227
+ bool accept = false;
228
+ if (params.sparams.temp > 0) {
229
+ // stochastic verification
230
+
231
+ llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
232
+ llama_sample_softmax(ctx_tgt, &dist_tgt);
233
+ float p_tgt = 0, p_dft = 0;
234
+
235
+ // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
236
+
237
+ while (active_seqs.size() > 0) {
238
+ // randomly select a sequence to verify from active sequences
239
+ std::uniform_int_distribution<unsigned int> u_int_dist(0, active_seqs.size() - 1);
240
+ int s = *std::next(active_seqs.begin(), u_int_dist(rng));
241
+ if (i_dft >= (int) drafts[s].tokens.size()) {
242
+ drafts[s].active = false;
243
+ active_seqs.erase(s);
244
+ continue;
245
+ }
246
+ if (accept) {
247
+ // if we already accepted a token, we can skip the rest
248
+ if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
249
+ drafts[s].active = false;
250
+ active_seqs.erase(s);
251
+ }
252
+ continue;
253
+ }
254
+ LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
255
+ float r = u_dist(rng);
256
+ llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
257
+ // acquire the token probabilities assigned by the draft and target models
258
+ for (size_t i = 0; i < dist_tgt.size; i++) {
259
+ if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
260
+ p_tgt = dist_tgt.data[i].p;
261
+ }
262
+ if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
263
+ p_dft = dist_dft.data[i].p;
264
+ }
265
+ if (p_tgt && p_dft) {
266
+ break;
267
+ }
268
+ }
269
+ LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
270
+ if (r <= p_tgt / p_dft) {
271
+ s_keep = s;
272
+ accept = true;
273
+ token_id = drafts[s].tokens[i_dft];
274
+ token_str = llama_token_to_piece(ctx_tgt, token_id);
275
+ llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
276
+
277
+ LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
278
+ break;
279
+ } else {
280
+ LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
281
+ drafts[s].active = false;
282
+
283
+ // calculate residual probability
284
+ GGML_ASSERT(dist_tgt.sorted);
285
+ GGML_ASSERT(dist_dft.sorted);
286
+ float sum_probs = 0.0f;
287
+
288
+ // sort dist by id
289
+ std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
290
+ return a.id < b.id;
291
+ });
292
+ std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
293
+ return a.id < b.id;
294
+ });
295
+
296
+ for (size_t i = 0; i < dist_tgt.size; i++) {
297
+ dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
298
+ sum_probs += dist_tgt.data[i].p;
299
+ }
300
+ for (size_t i = 0; i < dist_tgt.size; i++) {
301
+ dist_tgt.data[i].p /= sum_probs;
302
+ }
303
+
304
+ // sort dist_tgt by p desc
305
+ std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
306
+ return a.p > b.p;
307
+ });
308
+ }
309
+
310
+ active_seqs.erase(s);
311
+ for(int i = 0; i < n_seq_dft; i++) {
312
+ if (i == s) {
313
+ continue;
314
+ }
315
+ if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
316
+ // synchronize active status for sequences with the same drafted token
317
+ drafts[i].active = drafts[i].active && accept;
318
+ if (!drafts[i].active) {
319
+ active_seqs.erase(s);
320
+ }
321
+ }
322
+ }
323
+ }
324
+
325
+ if (!accept) {
326
+ // all drafted tokens were rejected
327
+ // sample from the target model
328
+ LOG("all drafted tokens were rejected, sampling from residual distribution\n");
329
+ token_id = llama_sample_token(ctx_tgt, &dist_tgt);
330
+ llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
331
+ token_str = llama_token_to_piece(ctx_tgt, token_id);
332
+ }
333
+
334
+ } else {
335
+ // greedy verification
336
+
337
+ // sample from the target model
338
+ LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
339
+ token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
340
+
341
+ llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
342
+
343
+ //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
344
+
345
+ token_str = llama_token_to_piece(ctx_tgt, token_id);
346
+
347
+ for (int s = 0; s < n_seq_dft; ++s) {
348
+ if (!drafts[s].active) {
349
+ continue;
350
+ }
351
+
352
+ if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
353
+ LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
354
+
355
+ s_keep = s;
356
+ accept = true;
357
+ } else {
358
+ drafts[s].active = false;
359
+ }
360
+ }
361
+ }
362
+
363
+ if (llama_token_is_eog(model_tgt, token_id)) {
364
+ has_eos = true;
365
+ }
366
+ ++n_predict;
367
+
368
+ if (accept) {
369
+ ++n_accept;
370
+ ++n_past_tgt;
371
+ ++n_past_dft;
372
+ ++i_dft;
373
+ if (params.use_color) {
374
+ // Color token according to its origin sequence
375
+ printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
376
+ } else {
377
+ printf("%s", token_str.c_str());
378
+ }
379
+ fflush(stdout);
380
+ continue;
381
+ } else {
382
+ printf("%s", token_str.c_str());
383
+ fflush(stdout);
384
+ break;
385
+ }
386
+ }
387
+ }
388
+
389
+ {
390
+ LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
391
+
392
+ // TODO: simplify
393
+ {
394
+ LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
395
+
396
+ llama_kv_cache_seq_keep(ctx_dft, s_keep);
397
+ llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
398
+ llama_kv_cache_seq_keep(ctx_dft, 0);
399
+
400
+ llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
401
+ llama_kv_cache_seq_keep(ctx_tgt, s_keep);
402
+ llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
403
+ llama_kv_cache_seq_keep(ctx_tgt, 0);
404
+ }
405
+
406
+ for (int s = 0; s < n_seq_dft; ++s) {
407
+ drafts[s].active = false;
408
+ drafts[s].tokens.clear();
409
+ drafts[s].i_batch_tgt.clear();
410
+ drafts[s].dists.clear();
411
+ }
412
+ // note: will be erased after the speculation phase
413
+ drafts[0].tokens.push_back(token_id);
414
+ drafts[0].dists.push_back(std::vector<llama_token_data>());
415
+ drafts[0].i_batch_tgt.push_back(0);
416
+
417
+ llama_batch_clear(batch_dft);
418
+ llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
419
+
420
+ llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
421
+ // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
422
+ llama_decode(ctx_dft, batch_dft);
423
+
424
+ ++n_past_dft;
425
+ }
426
+
427
+ if (n_predict > params.n_predict || has_eos) {
428
+ break;
429
+ }
430
+
431
+ llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
432
+
433
+ int n_seq_cur = 1;
434
+ int n_past_cur = n_past_dft;
435
+
436
+ for (int s = 0; s < n_seq_dft; ++s) {
437
+ drafts[s].active = false;
438
+ drafts[s].drafting = false;
439
+ }
440
+ drafts[0].active = true;
441
+ drafts[0].drafting = true;
442
+ drafts[0].i_batch_dft = 0;
443
+
444
+ llama_batch_clear(batch_tgt);
445
+ llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
446
+
447
+ // sample n_draft tokens from the draft model using tree-based sampling
448
+ for (int i = 0; i < n_draft; ++i) {
449
+ batch_dft.n_tokens = 0;
450
+
451
+ for (int s = 0; s < n_seq_dft; ++s) {
452
+ drafts[s].skip = false;
453
+ }
454
+
455
+ for (int s = 0; s < n_seq_dft; ++s) {
456
+ if (!drafts[s].drafting || drafts[s].skip) {
457
+ continue;
458
+ }
459
+
460
+ llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
461
+
462
+ const auto & cur_p = drafts[s].ctx_sampling->cur;
463
+
464
+ for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
465
+ LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
466
+ k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
467
+ }
468
+
469
+ std::vector<int> sa(1, s);
470
+
471
+ // attempt to split the branch if the probability is high enough
472
+ for (int f = 1; f < 8; ++f) {
473
+ if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
474
+ LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
475
+
476
+ llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
477
+ llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
478
+
479
+ // all previous tokens from this branch are now also part of the new branch
480
+ for (int t = 0; t < batch_tgt.n_tokens; ++t) {
481
+ for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
482
+ if (batch_tgt.seq_id[t][p] == s) {
483
+ batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
484
+ batch_tgt.n_seq_id[t]++;
485
+ break;
486
+ }
487
+ }
488
+ }
489
+
490
+ // copy the draft state
491
+ drafts[n_seq_cur].active = true;
492
+ drafts[n_seq_cur].drafting = true;
493
+ drafts[n_seq_cur].skip = true;
494
+
495
+ drafts[n_seq_cur].tokens = drafts[s].tokens;
496
+ drafts[n_seq_cur].dists = drafts[s].dists;
497
+ drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
498
+ drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
499
+
500
+ llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
501
+
502
+ sa.push_back(n_seq_cur);
503
+
504
+ n_seq_cur++;
505
+ } else {
506
+ break;
507
+ }
508
+ }
509
+
510
+ // add drafted token for each sequence
511
+ for (int is = 0; is < (int) sa.size(); ++is) {
512
+ const llama_token id = cur_p[is].id;
513
+
514
+ const int s = sa[is];
515
+
516
+ llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
517
+
518
+ drafts[s].tokens.push_back(id);
519
+ // save cur_p.data into drafts[s].dists
520
+ drafts[s].dists.push_back(cur_p);
521
+
522
+ // add unique drafted tokens to the target batch
523
+ drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
524
+
525
+ llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
526
+
527
+ // add the token to the batch for batched decoding with the draft model
528
+ drafts[s].i_batch_dft = batch_dft.n_tokens;
529
+
530
+ llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
531
+
532
+ if (batch_tgt.n_tokens > n_draft) {
533
+ drafts[s].drafting = false;
534
+ }
535
+ }
536
+ }
537
+
538
+ // no sequence is drafting anymore
539
+ if (batch_dft.n_tokens == 0) {
540
+ break;
541
+ }
542
+
543
+ // evaluate the drafted tokens on the draft model
544
+ llama_decode(ctx_dft, batch_dft);
545
+ ++n_past_cur;
546
+ ++n_drafted;
547
+
548
+ if (batch_tgt.n_tokens > n_draft) {
549
+ break;
550
+ }
551
+ }
552
+
553
+ // evaluate the target model on the drafted tokens
554
+ {
555
+ llama_kv_cache_seq_keep(ctx_tgt, 0);
556
+ for (int s = 1; s < n_seq_dft; ++s) {
557
+ llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
558
+ }
559
+
560
+ // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
561
+ llama_decode(ctx_tgt, batch_tgt);
562
+ ++n_past_tgt;
563
+ }
564
+
565
+ // the first token is always proposed by the target model before the speculation loop so we erase it here
566
+ for (int s = 0; s < n_seq_dft; ++s) {
567
+ if (!drafts[s].active) {
568
+ continue;
569
+ }
570
+
571
+ drafts[s].tokens.erase(drafts[s].tokens.begin());
572
+ drafts[s].dists.erase(drafts[s].dists.begin());
573
+ }
574
+ }
575
+
576
+ auto t_dec_end = ggml_time_us();
577
+
578
+ LOG_TEE("\n\n");
579
+
580
+ LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
581
+ LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
582
+
583
+ LOG_TEE("\n");
584
+ LOG_TEE("n_draft = %d\n", n_draft);
585
+ LOG_TEE("n_predict = %d\n", n_predict);
586
+ LOG_TEE("n_drafted = %d\n", n_drafted);
587
+ LOG_TEE("n_accept = %d\n", n_accept);
588
+ LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
589
+
590
+ LOG_TEE("\ndraft:\n");
591
+ llama_print_timings(ctx_dft);
592
+
593
+ LOG_TEE("\ntarget:\n");
594
+ llama_print_timings(ctx_tgt);
595
+
596
+ llama_sampling_free(ctx_sampling);
597
+ for (int s = 0; s < n_seq_dft; ++s) {
598
+ llama_sampling_free(drafts[s].ctx_sampling);
599
+ }
600
+
601
+ llama_batch_free(batch_dft);
602
+
603
+ llama_free(ctx_tgt);
604
+ llama_free_model(model_tgt);
605
+
606
+ llama_free(ctx_dft);
607
+ llama_free_model(model_dft);
608
+
609
+ llama_backend_free();
610
+
611
+ fprintf(stderr, "\n\n");
612
+
613
+ return 0;
614
+ }
@@ -0,0 +1,9 @@
1
+ # MIT license
2
+ # Copyright (C) 2024 Intel Corporation
3
+ # SPDX-License-Identifier: MIT
4
+
5
# Standalone tool that lists the SYCL devices reported by ggml's SYCL backend.
set(TARGET ls-sycl-device)

# Build configuration: single source file, linked against the project's
# common helpers and llama library; this tool requires C++17.
add_executable(${TARGET} ls-sycl-device.cpp)
target_link_libraries(${TARGET}
    PRIVATE
        common
        llama
        ${CMAKE_THREAD_LIBS_INIT}
)
target_compile_features(${TARGET} PRIVATE cxx_std_17)

# Install the resulting executable as a runtime artifact.
install(TARGETS ${TARGET} RUNTIME)
@@ -0,0 +1,13 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+
8
+ #include "ggml-sycl.h"
9
+
10
+ int main() {
11
+ ggml_backend_sycl_print_sycl_devices();
12
+ return 0;
13
+ }
@@ -0,0 +1,5 @@
1
# Command-line tool that tokenizes a prompt with a model's vocabulary.
set(TARGET tokenize)

# Build configuration: single source file, linked against the project's
# common helpers and llama library; compiled as C++11.
add_executable(${TARGET} tokenize.cpp)
target_link_libraries(${TARGET}
    PRIVATE
        common
        llama
        ${CMAKE_THREAD_LIBS_INIT}
)
target_compile_features(${TARGET} PRIVATE cxx_std_11)

# Install the resulting executable as a runtime artifact.
install(TARGETS ${TARGET} RUNTIME)
@@ -0,0 +1,42 @@
1
+ #include "common.h"
2
+ #include "llama.h"
3
+
4
+ #include <cmath>
5
+ #include <cstdio>
6
+ #include <string>
7
+ #include <vector>
8
+
9
+ int main(int argc, char ** argv) {
10
+ if (argc < 3 || argv[1][0] == '-') {
11
+ printf("usage: %s MODEL_PATH PROMPT [--ids]\n" , argv[0]);
12
+ return 1;
13
+ }
14
+
15
+ const char * model_path = argv[1];
16
+ const char * prompt = argv[2];
17
+
18
+ const bool printing_ids = argc > 3 && std::string(argv[3]) == "--ids";
19
+
20
+ llama_backend_init();
21
+
22
+ llama_model_params model_params = llama_model_default_params();
23
+ model_params.vocab_only = true;
24
+ llama_model * model = llama_load_model_from_file(model_path, model_params);
25
+
26
+ llama_context_params ctx_params = llama_context_default_params();
27
+ llama_context * ctx = llama_new_context_with_model(model, ctx_params);
28
+
29
+ std::vector<llama_token> tokens;
30
+
31
+ tokens = ::llama_tokenize(model, prompt, true, true);
32
+
33
+ for (int i = 0; i < (int) tokens.size(); i++) {
34
+ if (printing_ids) {
35
+ printf("%d\n", tokens[i]);
36
+ } else {
37
+ printf("%6d -> '%s'\n", tokens[i], llama_token_to_piece(ctx, tokens[i]).c_str());
38
+ }
39
+ }
40
+
41
+ return 0;
42
+ }
@@ -0,0 +1,5 @@
1
# Example program that trains a small text model from scratch.
set(TARGET train-text-from-scratch)

# Build configuration: single source file, linked against the project's
# common helpers and llama library; compiled as C++11.
add_executable(${TARGET} train-text-from-scratch.cpp)
target_link_libraries(${TARGET}
    PRIVATE
        common
        llama
        ${CMAKE_THREAD_LIBS_INIT}
)
target_compile_features(${TARGET} PRIVATE cxx_std_11)

# Install the resulting executable as a runtime artifact.
install(TARGETS ${TARGET} RUNTIME)