@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187):
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
@@ -2,33 +2,18 @@
2
2
  #undef NDEBUG
3
3
  #endif
4
4
 
5
- #define LLAMA_API_INTERNAL
6
-
7
- #include "ggml.h"
8
- #include "llama.h"
9
- #include "grammar-parser.h"
10
- #include "json-schema-to-grammar.h"
11
5
  #include "unicode.h"
6
+ #include "llama-grammar.h"
7
+ #include "json-schema-to-grammar.h"
8
+
12
9
  #include <cassert>
13
10
  #include <string>
14
11
  #include <vector>
15
12
 
16
13
  using json = nlohmann::ordered_json;
17
14
 
18
- static llama_grammar* build_grammar(const std::string & grammar_str) {
19
- auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
20
-
21
- // Ensure we parsed correctly
22
- assert(!parsed_grammar.rules.empty());
23
-
24
- // Ensure we have a root node
25
- assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
26
-
27
- std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
28
- llama_grammar* grammar = llama_grammar_init(
29
- grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
30
-
31
- return grammar;
15
+ static llama_grammar * build_grammar(const std::string & grammar_str) {
16
+ return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
32
17
  }
33
18
 
34
19
  static bool test_build_grammar_fails(const std::string & grammar_str) {
@@ -45,25 +30,23 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
45
30
  }
46
31
 
47
32
  static bool match_string(const std::string & input, llama_grammar * grammar) {
48
- auto decoded = decode_utf8(input, {});
49
-
50
- const auto & code_points = decoded.first;
33
+ const auto cpts = unicode_cpts_from_utf8(input);
51
34
 
52
35
  const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
53
- llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
36
+ llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
54
37
 
55
- for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
56
- const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
38
+ for (const auto & cpt : cpts) {
39
+ const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
57
40
 
58
- llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
41
+ llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
59
42
 
60
- if (cur_stacks.empty()) {
43
+ if (stacks_cur.empty()) {
61
44
  // no stacks means that the grammar failed to match at this point
62
45
  return false;
63
46
  }
64
47
  }
65
48
 
66
- for (const auto & stack : cur_stacks) {
49
+ for (const auto & stack : stacks_cur) {
67
50
  if (stack.empty()) {
68
51
  // An empty stack means that the grammar has been completed
69
52
  return true;
@@ -77,12 +60,12 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
77
60
  fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
78
61
  fflush(stderr);
79
62
 
80
- auto grammar = build_grammar(grammar_str);
63
+ auto * grammar = build_grammar(grammar_str);
81
64
 
82
65
  // Save the original grammar stacks so that we can reset after every new string we want to test
83
- const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);
66
+ const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
84
67
 
85
- llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
68
+ llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
86
69
 
87
70
  fprintf(stderr, " 🔵 Valid strings:\n");
88
71
 
@@ -119,7 +102,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
119
102
  assert(matched);
120
103
 
121
104
  // Reset the grammar stacks
122
- cur_stacks = original_stacks;
105
+ stacks_cur = stacks_org;
123
106
  }
124
107
 
125
108
  fprintf(stderr, " 🟠 Invalid strings:\n");
@@ -139,11 +122,11 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
139
122
  assert(!matched);
140
123
 
141
124
  // Reset the grammar stacks
142
- cur_stacks = original_stacks;
125
+ stacks_cur = stacks_org;
143
126
  }
144
127
 
145
128
  // Clean up allocated memory
146
- llama_grammar_free(grammar);
129
+ llama_grammar_free_impl(grammar);
147
130
  }
148
131
  static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
149
132
  test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
@@ -503,7 +486,7 @@ static void test_special_chars() {
503
486
  "aaaaabcccc",
504
487
  "aaaabccc",
505
488
  "aaaabccccc",
506
- "🔵🟠✅❌abc❌✅🟠🔵"
489
+ "🔵🟠✅❌abc❌✅🟠🔵",
507
490
  "🔵🟠abc🟠🔵"
508
491
  }
509
492
  );
@@ -683,7 +666,8 @@ static void test_failure_missing_root() {
683
666
  term ::= number
684
667
  number ::= [0-9]+)""";
685
668
 
686
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
669
+ llama_grammar_parser parsed_grammar;
670
+ parsed_grammar.parse(grammar_str.c_str());
687
671
 
688
672
  // Ensure we parsed correctly
689
673
  assert(!parsed_grammar.rules.empty());
@@ -705,7 +689,8 @@ static void test_failure_missing_reference() {
705
689
 
706
690
  fprintf(stderr, " Expected error: ");
707
691
 
708
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
692
+ llama_grammar_parser parsed_grammar;
693
+ parsed_grammar.parse(grammar_str.c_str());
709
694
 
710
695
  // Ensure we did NOT parsed correctly
711
696
  assert(parsed_grammar.rules.empty());
@@ -3,7 +3,7 @@
3
3
  #endif
4
4
 
5
5
  #include "llama.h"
6
- #include "grammar-parser.h"
6
+ #include "llama-grammar.h"
7
7
 
8
8
  #include <cassert>
9
9
 
@@ -22,7 +22,8 @@ static const char * type_str(llama_gretype type) {
22
22
 
23
23
  static void verify_parsing(const char *grammar_bytes, const std::vector<std::pair<std::string, uint32_t>> expected, const std::vector<llama_grammar_element> &expected_rules) {
24
24
  uint32_t index = 0;
25
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes);
25
+ llama_grammar_parser parsed_grammar;
26
+ parsed_grammar.parse(grammar_bytes);
26
27
 
27
28
  std::map<uint32_t, std::string> symbol_names;
28
29
  for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
@@ -129,9 +130,10 @@ static void verify_parsing(const char *grammar_bytes, const std::vector<std::pai
129
130
  }
130
131
  }
131
132
 
132
- static void verify_failure(const char *grammar_bytes) {
133
+ static void verify_failure(const char * grammar_bytes) {
133
134
  fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes);
134
- auto result = grammar_parser::parse(grammar_bytes);
135
+ llama_grammar_parser result;
136
+ result.parse(grammar_bytes);
135
137
  assert(result.rules.empty() && "should have failed");
136
138
  }
137
139
 
@@ -2,14 +2,15 @@
2
2
  #undef NDEBUG
3
3
  #endif
4
4
 
5
+ #include "json-schema-to-grammar.h"
6
+
7
+ #include "llama-grammar.h"
8
+
5
9
  #include <cassert>
6
10
  #include <fstream>
7
11
  #include <sstream>
8
12
  #include <regex>
9
13
 
10
- #include "json-schema-to-grammar.h"
11
- #include "grammar-parser.h"
12
-
13
14
  static std::string trim(const std::string & source) {
14
15
  std::string s(source);
15
16
  s.erase(0,s.find_first_not_of(" \n\r\t"));
@@ -40,7 +41,8 @@ struct TestCase {
40
41
  }
41
42
  void verify_expectation_parseable() const {
42
43
  try {
43
- auto state = grammar_parser::parse(expected_grammar.c_str());
44
+ llama_grammar_parser state;
45
+ state.parse(expected_grammar.c_str());
44
46
  if (state.symbol_ids.find("root") == state.symbol_ids.end()) {
45
47
  throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar);
46
48
  }
@@ -2,16 +2,15 @@
2
2
  #undef NDEBUG
3
3
  #endif
4
4
 
5
- #define LLAMA_API_INTERNAL
6
5
  #include "llama.h"
7
- #include "grammar-parser.h"
6
+ #include "llama-grammar.h"
8
7
 
9
8
  #include <cassert>
10
9
  #include <stdexcept>
11
10
 
12
11
  int main()
13
12
  {
14
- grammar_parser::parse_state parsed_grammar;
13
+ llama_grammar_parser parsed_grammar;
15
14
 
16
15
  std::vector<std::pair<std::string, uint32_t>> expected = {
17
16
  {"expr", 2},
@@ -117,7 +116,7 @@ int main()
117
116
  llama_grammar * grammar = NULL;
118
117
  std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
119
118
 
120
- grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
119
+ grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
121
120
  if (grammar == nullptr)
122
121
  {
123
122
  throw std::runtime_error("Failed to initialize llama_grammar");
@@ -174,13 +173,13 @@ int main()
174
173
  }};
175
174
 
176
175
  auto index = 0;
177
- for (auto stack : llama_grammar_get_stacks(grammar))
176
+ for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
178
177
  {
179
178
  // compare stack to expected_stack
180
179
  for (uint32_t i = 0; i < stack.size(); i++)
181
180
  {
182
- auto element = stack[i];
183
- auto expected_element = expected_stacks[index][i];
181
+ const llama_grammar_element * element = stack[i];
182
+ const llama_grammar_element & expected_element = expected_stacks[index][i];
184
183
 
185
184
  // pretty print error message before asserting
186
185
  if (expected_element.type != element->type || expected_element.value != element->value)
@@ -403,6 +402,8 @@ int main()
403
402
  delete[] candidate.code_points;
404
403
  candidate.code_points = nullptr;
405
404
  }
406
- llama_grammar_free(grammar);
405
+
406
+ llama_grammar_free_impl(grammar);
407
+
407
408
  return 0;
408
409
  }
@@ -0,0 +1,39 @@
1
+ #include "log.h"
2
+
3
+ #include <cstdlib>
4
+ #include <thread>
5
+
6
+ int main() {
7
+ const int n_thread = 8;
8
+
9
+ std::thread threads[n_thread];
10
+ for (int i = 0; i < n_thread; i++) {
11
+ threads[i] = std::thread([i]() {
12
+ const int n_msg = 1000;
13
+
14
+ for (int j = 0; j < n_msg; j++) {
15
+ const int log_type = std::rand() % 4;
16
+
17
+ switch (log_type) {
18
+ case 0: LOG_INF("Thread %d: %d\n", i, j); break;
19
+ case 1: LOG_WRN("Thread %d: %d\n", i, j); break;
20
+ case 2: LOG_ERR("Thread %d: %d\n", i, j); break;
21
+ case 3: LOG_DBG("Thread %d: %d\n", i, j); break;
22
+ default:
23
+ break;
24
+ }
25
+
26
+ if (rand () % 10 < 5) {
27
+ gpt_log_set_timestamps(gpt_log_main(), rand() % 2);
28
+ gpt_log_set_prefix (gpt_log_main(), rand() % 2);
29
+ }
30
+ }
31
+ });
32
+ }
33
+
34
+ for (int i = 0; i < n_thread; i++) {
35
+ threads[i].join();
36
+ }
37
+
38
+ return 0;
39
+ }
@@ -15,11 +15,13 @@
15
15
 
16
16
  constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
17
17
  constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
18
+ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.01f;
18
19
  constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
19
20
  constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
20
21
  constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
21
22
  constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
22
23
  constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
24
+ constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f;
23
25
 
24
26
  static const char* RESULT_STR[] = {"ok", "FAILED"};
25
27
 
@@ -144,6 +146,8 @@ int main(int argc, char * argv[]) {
144
146
  if (qfns.from_float && qfns.to_float) {
145
147
  const float total_error = total_quantization_error(qfns, test_size, test_data.data());
146
148
  const float max_quantization_error =
149
+ type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
150
+ type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
147
151
  type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
148
152
  type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
149
153
  type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
@@ -166,6 +170,8 @@ int main(int argc, char * argv[]) {
166
170
  const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
167
171
  type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
168
172
  ? MAX_DOT_PRODUCT_ERROR_LOWBIT
173
+ : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
174
+ ? MAX_DOT_PRODUCT_ERROR_TERNARY
169
175
  : MAX_DOT_PRODUCT_ERROR;
170
176
  failed = !(vec_dot_error < max_allowed_error);
171
177
  num_failed += failed;
@@ -113,7 +113,7 @@ static struct ggml_tensor * get_random_tensor_f32(
113
113
  }
114
114
 
115
115
  static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
116
- struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
116
+ struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
117
117
 
118
118
  if (plan.work_size > 0) {
119
119
  buf.resize(plan.work_size);