@fugood/llama.node 0.3.6 → 0.3.8

This diff shows the published contents of the two package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (186)
  1. package/README.md +17 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +3 -1
  19. package/lib/index.js +16 -1
  20. package/lib/index.ts +16 -0
  21. package/package.json +1 -1
  22. package/src/EmbeddingWorker.cpp +4 -3
  23. package/src/LlamaCompletionWorker.cpp +4 -2
  24. package/src/LlamaContext.cpp +61 -6
  25. package/src/LlamaContext.h +1 -0
  26. package/src/common.hpp +6 -11
  27. package/src/llama.cpp/.github/workflows/build.yml +19 -17
  28. package/src/llama.cpp/.github/workflows/docker.yml +77 -30
  29. package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +22 -3
  31. package/src/llama.cpp/CMakeLists.txt +49 -24
  32. package/src/llama.cpp/common/arg.cpp +82 -26
  33. package/src/llama.cpp/common/arg.h +3 -0
  34. package/src/llama.cpp/common/common.cpp +192 -72
  35. package/src/llama.cpp/common/common.h +51 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +12 -12
  37. package/src/llama.cpp/common/ngram-cache.h +2 -2
  38. package/src/llama.cpp/common/sampling.cpp +11 -6
  39. package/src/llama.cpp/common/speculative.cpp +18 -15
  40. package/src/llama.cpp/docs/build.md +2 -0
  41. package/src/llama.cpp/examples/batched/batched.cpp +9 -7
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
  43. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
  44. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
  45. package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
  46. package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
  47. package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
  48. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
  49. package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
  50. package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
  51. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
  52. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
  53. package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
  54. package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
  55. package/src/llama.cpp/examples/infill/infill.cpp +23 -24
  56. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
  57. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
  58. package/src/llama.cpp/examples/llava/clip.cpp +4 -2
  59. package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
  60. package/src/llama.cpp/examples/llava/llava.cpp +2 -2
  61. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
  62. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
  63. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
  64. package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
  65. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
  66. package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
  67. package/src/llama.cpp/examples/main/main.cpp +51 -29
  68. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
  69. package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
  70. package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
  71. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
  72. package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
  73. package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
  74. package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
  76. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
  77. package/src/llama.cpp/examples/run/run.cpp +175 -61
  78. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
  79. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
  80. package/src/llama.cpp/examples/server/httplib.h +1295 -409
  81. package/src/llama.cpp/examples/server/server.cpp +387 -181
  82. package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
  83. package/src/llama.cpp/examples/server/utils.hpp +170 -58
  84. package/src/llama.cpp/examples/simple/simple.cpp +9 -8
  85. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
  86. package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
  87. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
  88. package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
  89. package/src/llama.cpp/examples/tts/tts.cpp +64 -23
  90. package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
  91. package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
  92. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
  93. package/src/llama.cpp/ggml/include/ggml.h +36 -145
  94. package/src/llama.cpp/ggml/include/gguf.h +202 -0
  95. package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
  96. package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
  97. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
  98. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
  99. package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
  100. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
  101. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
  102. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
  103. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
  105. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
  106. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
  107. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
  109. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
  111. package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
  112. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
  113. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
  115. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
  117. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
  120. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
  121. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
  124. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
  125. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
  126. package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
  128. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
  129. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
  130. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
  131. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
  132. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
  133. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
  134. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
  135. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
  138. package/src/llama.cpp/ggml/src/ggml.c +117 -1327
  139. package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
  140. package/src/llama.cpp/include/llama-cpp.h +6 -1
  141. package/src/llama.cpp/include/llama.h +138 -75
  142. package/src/llama.cpp/src/CMakeLists.txt +13 -1
  143. package/src/llama.cpp/src/llama-adapter.cpp +347 -0
  144. package/src/llama.cpp/src/llama-adapter.h +74 -0
  145. package/src/llama.cpp/src/llama-arch.cpp +1487 -0
  146. package/src/llama.cpp/src/llama-arch.h +400 -0
  147. package/src/llama.cpp/src/llama-batch.cpp +368 -0
  148. package/src/llama.cpp/src/llama-batch.h +88 -0
  149. package/src/llama.cpp/src/llama-chat.cpp +578 -0
  150. package/src/llama.cpp/src/llama-chat.h +52 -0
  151. package/src/llama.cpp/src/llama-context.cpp +1775 -0
  152. package/src/llama.cpp/src/llama-context.h +128 -0
  153. package/src/llama.cpp/src/llama-cparams.cpp +1 -0
  154. package/src/llama.cpp/src/llama-cparams.h +37 -0
  155. package/src/llama.cpp/src/llama-grammar.cpp +5 -4
  156. package/src/llama.cpp/src/llama-grammar.h +3 -1
  157. package/src/llama.cpp/src/llama-hparams.cpp +71 -0
  158. package/src/llama.cpp/src/llama-hparams.h +139 -0
  159. package/src/llama.cpp/src/llama-impl.cpp +167 -0
  160. package/src/llama.cpp/src/llama-impl.h +16 -136
  161. package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
  162. package/src/llama.cpp/src/llama-kv-cache.h +218 -0
  163. package/src/llama.cpp/src/llama-mmap.cpp +589 -0
  164. package/src/llama.cpp/src/llama-mmap.h +67 -0
  165. package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
  166. package/src/llama.cpp/src/llama-model-loader.h +167 -0
  167. package/src/llama.cpp/src/llama-model.cpp +3953 -0
  168. package/src/llama.cpp/src/llama-model.h +370 -0
  169. package/src/llama.cpp/src/llama-quant.cpp +934 -0
  170. package/src/llama.cpp/src/llama-quant.h +1 -0
  171. package/src/llama.cpp/src/llama-sampling.cpp +147 -32
  172. package/src/llama.cpp/src/llama-sampling.h +3 -19
  173. package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
  174. package/src/llama.cpp/src/llama-vocab.h +97 -142
  175. package/src/llama.cpp/src/llama.cpp +7160 -20314
  176. package/src/llama.cpp/src/unicode.cpp +8 -3
  177. package/src/llama.cpp/tests/CMakeLists.txt +2 -0
  178. package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
  179. package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
  180. package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
  181. package/src/llama.cpp/tests/test-gguf.cpp +222 -187
  182. package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
  183. package/src/llama.cpp/tests/test-sampling.cpp +0 -1
  184. package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
  185. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
  186. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
package/src/llama.cpp/src/llama-quant.h (new file)
@@ -0,0 +1 @@
+#pragma once
package/src/llama.cpp/src/llama-sampling.cpp
@@ -1,5 +1,6 @@
 #include "llama-sampling.h"
 
+#include "llama-impl.h"
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
@@ -14,6 +15,118 @@
 #include <numeric>
 #include <random>
 #include <unordered_map>
+#include <stdexcept>
+
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    void push_back(const T & value) {
+        if (capacity == 0) {
+            throw std::runtime_error("ring buffer: capacity is zero");
+        }
+
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
+    }
+
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
+
+    //T & operator[](size_t i) {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    //const T & at(size_t i) const {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + sz - i - 1) % capacity];
+    }
+
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
+
+    void clear() {
+        // here only reset the status of the buffer
+        sz = 0;
+        first = 0;
+        pos = 0;
+    }
+
+    bool empty() const {
+        return sz == 0;
+    }
+
+    size_t size() const {
+        return sz;
+    }
+
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+
+    std::vector<T> data;
+};
 
 static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
     // iterator for the probabilities
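
Note on the ring_buffer added above: it is a fixed-capacity, std::deque-like container (per its own comment) that the penalty and DRY samplers use for their recent-token history, trading dynamic growth for O(1) eviction of the oldest element. A minimal behavioral sketch, with illustrative values that are not part of the diff, assuming the ring_buffer template defined in the hunk above:

    ring_buffer<int> rb(3);        // fixed capacity of 3
    rb.push_back(1);
    rb.push_back(2);
    rb.push_back(3);
    rb.push_back(4);               // buffer full: the oldest element (1) is evicted

    int oldest = rb.front();       // 2
    int newest = rb.rat(0);        // 4 -- rat(i) counts backwards from the newest element
    std::vector<int> v = rb.to_vector();  // {2, 3, 4}, oldest to newest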
@@ -144,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
         for (int i = 0; i < (int)cur_p->size; ++i) {
             const float val = cur_p->data[i].logit;
             int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-            ib = std::max(0, std::min(nbuckets-1, ib));
+            ib = std::max(0, std::min(nbuckets - 1, ib));
             bucket_idx[i] = ib;
             ++histo[ib];
         }
@@ -167,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
         for (int i = 0; i < (int)cur_p->size; ++i) {
             int j = bucket_idx[i];
             if (j >= ib) {
-                *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
+                *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
             }
         }
 
         ptr = tmp_tokens.data();
         int ndone = 0;
-        for (int j = nbuckets-1; j > ib; --j) {
+        for (int j = nbuckets - 1; j > ib; --j) {
             std::sort(ptr, ptr + histo[j], comp);
             ptr += histo[j];
             ndone += histo[j];
@@ -258,7 +371,10 @@ void llama_sampler_free(struct llama_sampler * smpl) {
 llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
     const auto * logits = llama_get_logits_ith(ctx, idx);
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
 
     // TODO: do not allocate each time
     std::vector<llama_token_data> cur;
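
This hunk shows the vocab-accessor migration that recurs throughout the diff: vocabulary queries now go through an explicit llama_vocab handle obtained via llama_model_get_vocab() instead of model-level helpers like llama_n_vocab(). A caller-side sketch of the same pattern (ctx is assumed to be an existing llama_context *):

    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const int32_t n_tokens = llama_vocab_n_tokens(vocab);  // formerly llama_n_vocab(model)
    const llama_token bos  = llama_vocab_bos(vocab);       // formerly llama_token_bos(model)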
@@ -1332,7 +1448,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
 
-    auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+    auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr);
 
     // copy the state
     {
@@ -1368,19 +1484,19 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
     /* .free   = */ llama_sampler_grammar_free,
 };
 
-struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
     auto * ctx = new llama_sampler_grammar;
 
     if (grammar_str != nullptr && grammar_str[0] != '\0') {
         *ctx = {
-            /* .vocab        = */ &vocab,
+            /* .vocab        = */ vocab,
             /* .grammar_str  = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root),
         };
     } else {
         *ctx = {
-            /* .vocab        = */ &vocab,
+            /* .vocab        = */ vocab,
             /* .grammar_str  = */ {},
             /* .grammar_root = */ {},
             /* .grammar      = */ nullptr,
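
With the _impl suffix dropped, llama_sampler_init_grammar() is now the public constructor and takes the vocab by pointer. A minimal usage sketch; the GBNF string and surrounding variables are hypothetical:

    const llama_vocab * vocab = llama_model_get_vocab(model);

    const char * grammar = "root ::= [0-9]+";  // hypothetical digits-only grammar
    llama_sampler * smpl = llama_sampler_init_grammar(vocab, grammar, "root");

    // ... llama_sampler_sample(smpl, ctx, -1) / llama_sampler_accept(smpl, token) ...

    llama_sampler_free(smpl);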
@@ -1550,8 +1666,8 @@ struct llama_sampler_dry {
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
-    for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
-        std::string word = llama_detokenize(vocab, {token_id}, true);
+    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
+        std::string word = vocab.detokenize({token_id}, true);
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
         } else {
@@ -1568,7 +1684,7 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
                     }
                 }
                 if (match) {
-                    std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
+                    std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
                     if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                         tokenization.resize(max_tail_len);
                     }
@@ -1719,7 +1835,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
                 ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
                 if (n > 0) {
                     lt = k;
-                    rt = k+n-1;
+                    rt = k + n - 1;
                 }
             } else {
                 // If k is inside the current Z-box, consider two cases.
@@ -1824,7 +1940,7 @@ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler
     llama_vocab dummy_vocab;
 
     // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
-    auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
+    auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
 
     // Copy the state, including the processed breakers
     {
@@ -1851,7 +1967,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
     /* .free   = */ llama_sampler_dry_free,
 };
 
-struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
     int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
     std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
     const int MAX_CHAR_LEN = 40;
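
llama_sampler_init_dry() gets the same treatment: a public name taking the vocab by pointer. A sketch with hypothetical tuning values, matching the signature defined above:

    const char * breakers[] = { "\n", ":", "\"", "*" };  // hypothetical sequence breakers

    llama_sampler * dry = llama_sampler_init_dry(
        vocab,     // const llama_vocab *, e.g. from llama_model_get_vocab(model)
        4096,      // context_size: caps the penalty window
        0.8f,      // dry_multiplier
        1.75f,     // dry_base
        2,         // dry_allowed_length
        -1,        // dry_penalty_last_n (-1 = use the whole context)
        breakers,
        4);        // num_breakers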
@@ -1878,7 +1994,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
                 sequence_break.resize(MAX_CHAR_LEN);
             }
 
-            get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
+            get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
         }
     }
 
@@ -1901,7 +2017,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
 // wrapper for test-sampling.cpp
 struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
     llama_vocab dummy_vocab;
-    auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
+    auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
     auto * ctx = (llama_sampler_dry *) result->ctx;
 
     // Process the token-based sequence breakers
2040
2156
  float p_eog_sum = 0.0f;
2041
2157
 
2042
2158
  for (size_t i = 0; i < cur_p->size; ++i) {
2043
- if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
2159
+ if (ctx->vocab->is_eog(cur_p->data[i].id)) {
2044
2160
  p_eog_sum += cur_p->data[i].p;
2045
2161
  } else {
2046
2162
  p_txt_sum += cur_p->data[i].p;
@@ -2062,7 +2178,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
2062
2178
  float p_sum = 0.0f;
2063
2179
 
2064
2180
  for (size_t i = 0; i < size_org; ++i) {
2065
- if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
2181
+ if (ctx->vocab->is_eog(cur_p->data[i].id)) {
2066
2182
  p_sum += cur_p->data[i].p;
2067
2183
 
2068
2184
  cur_p->data[cur_p->size++] = cur_p->data[i];
@@ -2090,17 +2206,17 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
2090
2206
  continue;
2091
2207
  }
2092
2208
 
2093
- int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2209
+ int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2094
2210
  if (len0 < 0) {
2095
2211
  ctx->buf0.resize(len0);
2096
- len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2212
+ len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
2097
2213
  assert(len0 > 0);
2098
2214
  }
2099
2215
 
2100
- int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2216
+ int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2101
2217
  if (len1 < 0) {
2102
2218
  ctx->buf1.resize(len1);
2103
- len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2219
+ len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2104
2220
  assert(len1 > 0);
2105
2221
  }
2106
2222
 
@@ -2135,7 +2251,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
 
         if (cur_p->data[i].p < thold && !is_eog) {
             continue;
@@ -2156,7 +2272,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     // if no non-EOG tokens are left -> reduce cur_p to single EOT token
     if (n_non_eog == 0) {
         cur_p->size = 1;
-        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].id = ctx->vocab->token_eot();
         cur_p->data[0].logit = 1.0f;
 
         return;
@@ -2178,7 +2294,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
 
         if (cur_p->data[i].p < thold && !is_eog) {
             continue;
@@ -2201,7 +2317,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
 
 static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
-    return llama_sampler_init_infill_impl(*ctx->vocab);
+    return llama_sampler_init_infill(ctx->vocab);
 }
 
 static void llama_sampler_infill_free(struct llama_sampler * smpl) {
@@ -2217,14 +2333,13 @@ static struct llama_sampler_i llama_sampler_infill_i = {
     /* .free   = */ llama_sampler_infill_free,
 };
 
-struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab) {
+struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
     return new llama_sampler {
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
-            /* .vocab = */ &vocab,
-            /* .buf0 = */ std::vector<char>(512),
-            /* .buf1 = */ std::vector<char>(512),
+            /* .vocab = */ vocab,
+            /* .buf0  = */ std::vector<char>(512),
+            /* .buf1  = */ std::vector<char>(512),
         },
     };
 }
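
The infill sampler completes the set of renames; llama_sampler_init_infill() is now the public entry point, e.g. (vocab obtained as in the earlier sketches):

    llama_sampler * infill = llama_sampler_init_infill(vocab);
    // ... sampling ...
    llama_sampler_free(infill);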
package/src/llama.cpp/src/llama-sampling.h
@@ -2,7 +2,9 @@
 
 // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
 
-#include "llama-grammar.h"
+#include "llama.h"
+
+#include <vector>
 
 struct llama_vocab;
 struct llama_grammar;
@@ -21,24 +23,6 @@ struct llama_sampler_chain {
     mutable int32_t n_sample;
 };
 
-struct llama_sampler * llama_sampler_init_grammar_impl(
-        const struct llama_vocab & vocab,
-        const char * grammar_str,
-        const char * grammar_root);
-
-struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab);
-
-struct llama_sampler * llama_sampler_init_dry_impl(
-        const struct llama_vocab & vocab,
-        int32_t context_size,
-        float dry_multiplier,
-        float dry_base,
-        int32_t dry_allowed_length,
-        int32_t dry_penalty_last_n,
-        const char ** seq_breakers,
-        size_t num_breakers);
-
 struct llama_sampler * llama_sampler_init_dry_testing(
         int32_t context_size,
         float dry_multiplier,