@fugood/llama.node 0.0.1-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204) hide show
  1. package/CMakeLists.txt +85 -0
  2. package/README.md +56 -0
  3. package/bin/darwin/arm64/llama-node.node +0 -0
  4. package/bin/darwin/x64/llama-node.node +0 -0
  5. package/bin/linux/arm64/llama-node.node +0 -0
  6. package/bin/linux/x64/llama-node.node +0 -0
  7. package/bin/win32/arm64/llama-node.node +0 -0
  8. package/bin/win32/arm64/node.lib +0 -0
  9. package/bin/win32/x64/llama-node.node +0 -0
  10. package/bin/win32/x64/node.lib +0 -0
  11. package/lib/binding.js +13 -0
  12. package/lib/binding.ts +57 -0
  13. package/lib/index.js +24 -0
  14. package/lib/index.ts +13 -0
  15. package/package.json +65 -0
  16. package/src/addons.cpp +506 -0
  17. package/src/llama.cpp/CMakeLists.txt +1320 -0
  18. package/src/llama.cpp/build.zig +172 -0
  19. package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
  20. package/src/llama.cpp/common/CMakeLists.txt +87 -0
  21. package/src/llama.cpp/common/base64.hpp +392 -0
  22. package/src/llama.cpp/common/common.cpp +2949 -0
  23. package/src/llama.cpp/common/common.h +324 -0
  24. package/src/llama.cpp/common/console.cpp +501 -0
  25. package/src/llama.cpp/common/console.h +19 -0
  26. package/src/llama.cpp/common/grammar-parser.cpp +440 -0
  27. package/src/llama.cpp/common/grammar-parser.h +29 -0
  28. package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
  29. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
  30. package/src/llama.cpp/common/json.hpp +24766 -0
  31. package/src/llama.cpp/common/log.h +724 -0
  32. package/src/llama.cpp/common/ngram-cache.cpp +282 -0
  33. package/src/llama.cpp/common/ngram-cache.h +94 -0
  34. package/src/llama.cpp/common/sampling.cpp +353 -0
  35. package/src/llama.cpp/common/sampling.h +147 -0
  36. package/src/llama.cpp/common/stb_image.h +8396 -0
  37. package/src/llama.cpp/common/train.cpp +1513 -0
  38. package/src/llama.cpp/common/train.h +233 -0
  39. package/src/llama.cpp/examples/CMakeLists.txt +52 -0
  40. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
  41. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
  42. package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
  43. package/src/llama.cpp/examples/batched/batched.cpp +262 -0
  44. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
  45. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
  46. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
  47. package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
  48. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
  49. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
  50. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
  51. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
  52. package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
  54. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
  55. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
  56. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
  58. package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
  59. package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
  60. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
  61. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
  62. package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
  64. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
  65. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
  66. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
  67. package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
  68. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
  69. package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
  70. package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
  71. package/src/llama.cpp/examples/infill/infill.cpp +767 -0
  72. package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
  73. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
  74. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
  75. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
  76. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
  77. package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
  78. package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
  79. package/src/llama.cpp/examples/llava/clip.h +85 -0
  80. package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
  81. package/src/llama.cpp/examples/llava/llava.cpp +426 -0
  82. package/src/llama.cpp/examples/llava/llava.h +50 -0
  83. package/src/llama.cpp/examples/llava/requirements.txt +3 -0
  84. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
  85. package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
  86. package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
  87. package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
  88. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
  89. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
  90. package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
  91. package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
  92. package/src/llama.cpp/examples/main/main.cpp +957 -0
  93. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
  94. package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
  95. package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
  96. package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
  97. package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
  98. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
  99. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
  100. package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
  101. package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
  102. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
  103. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
  104. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
  106. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
  107. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
  108. package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
  109. package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
  110. package/src/llama.cpp/examples/server/httplib.h +9465 -0
  111. package/src/llama.cpp/examples/server/server.cpp +3826 -0
  112. package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
  113. package/src/llama.cpp/examples/server/utils.hpp +653 -0
  114. package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
  115. package/src/llama.cpp/examples/simple/simple.cpp +183 -0
  116. package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
  117. package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
  118. package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
  119. package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
  120. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
  121. package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
  122. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
  123. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
  124. package/src/llama.cpp/ggml-alloc.c +985 -0
  125. package/src/llama.cpp/ggml-alloc.h +76 -0
  126. package/src/llama.cpp/ggml-backend-impl.h +141 -0
  127. package/src/llama.cpp/ggml-backend.c +2099 -0
  128. package/src/llama.cpp/ggml-backend.h +233 -0
  129. package/src/llama.cpp/ggml-common.h +1853 -0
  130. package/src/llama.cpp/ggml-cuda.h +43 -0
  131. package/src/llama.cpp/ggml-impl.h +265 -0
  132. package/src/llama.cpp/ggml-kompute.cpp +2006 -0
  133. package/src/llama.cpp/ggml-kompute.h +46 -0
  134. package/src/llama.cpp/ggml-metal.h +66 -0
  135. package/src/llama.cpp/ggml-mpi.c +216 -0
  136. package/src/llama.cpp/ggml-mpi.h +39 -0
  137. package/src/llama.cpp/ggml-opencl.cpp +2301 -0
  138. package/src/llama.cpp/ggml-opencl.h +36 -0
  139. package/src/llama.cpp/ggml-quants.c +12678 -0
  140. package/src/llama.cpp/ggml-quants.h +133 -0
  141. package/src/llama.cpp/ggml-sycl.cpp +17882 -0
  142. package/src/llama.cpp/ggml-sycl.h +49 -0
  143. package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
  144. package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
  145. package/src/llama.cpp/ggml-vulkan.h +29 -0
  146. package/src/llama.cpp/ggml.c +21819 -0
  147. package/src/llama.cpp/ggml.h +2403 -0
  148. package/src/llama.cpp/llama.cpp +17468 -0
  149. package/src/llama.cpp/llama.h +1117 -0
  150. package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
  151. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
  152. package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
  153. package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
  154. package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
  155. package/src/llama.cpp/prompts/alpaca.txt +1 -0
  156. package/src/llama.cpp/prompts/assistant.txt +31 -0
  157. package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
  158. package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
  159. package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
  160. package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
  161. package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
  162. package/src/llama.cpp/prompts/chat.txt +28 -0
  163. package/src/llama.cpp/prompts/dan-modified.txt +1 -0
  164. package/src/llama.cpp/prompts/dan.txt +1 -0
  165. package/src/llama.cpp/prompts/mnemonics.txt +93 -0
  166. package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
  167. package/src/llama.cpp/prompts/reason-act.txt +18 -0
  168. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
  169. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
  170. package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
  171. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
  172. package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
  173. package/src/llama.cpp/requirements.txt +12 -0
  174. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
  175. package/src/llama.cpp/scripts/xxd.cmake +16 -0
  176. package/src/llama.cpp/sgemm.cpp +999 -0
  177. package/src/llama.cpp/sgemm.h +12 -0
  178. package/src/llama.cpp/tests/CMakeLists.txt +78 -0
  179. package/src/llama.cpp/tests/get-model.cpp +21 -0
  180. package/src/llama.cpp/tests/get-model.h +2 -0
  181. package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
  182. package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
  183. package/src/llama.cpp/tests/test-c.c +7 -0
  184. package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
  185. package/src/llama.cpp/tests/test-double-float.cpp +57 -0
  186. package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
  187. package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
  188. package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
  189. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
  190. package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
  191. package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
  192. package/src/llama.cpp/tests/test-opt.cpp +181 -0
  193. package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
  194. package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
  195. package/src/llama.cpp/tests/test-rope.cpp +221 -0
  196. package/src/llama.cpp/tests/test-sampling.cpp +301 -0
  197. package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
  198. package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
  199. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
  200. package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
  201. package/src/llama.cpp/unicode-data.cpp +1651 -0
  202. package/src/llama.cpp/unicode-data.h +16 -0
  203. package/src/llama.cpp/unicode.cpp +277 -0
  204. package/src/llama.cpp/unicode.h +28 -0
@@ -0,0 +1,402 @@
1
+ #ifdef NDEBUG
2
+ #undef NDEBUG
3
+ #endif
4
+
5
+ #include "llama.cpp" // TODO: not great
6
+ #include "grammar-parser.h"
7
+
8
+ #include <cassert>
9
+
10
+ int main()
11
+ {
12
+ grammar_parser::parse_state parsed_grammar;
13
+
14
+ std::vector<std::pair<std::string, uint32_t>> expected = {
15
+ {"expr", 2},
16
+ {"expr_6", 6},
17
+ {"expr_7", 7},
18
+ {"ident", 8},
19
+ {"ident_10", 10},
20
+ {"num", 9},
21
+ {"num_11", 11},
22
+ {"root", 0},
23
+ {"root_1", 1},
24
+ {"root_5", 5},
25
+ {"term", 4},
26
+ {"ws", 3},
27
+ {"ws_12", 12},
28
+ };
29
+
30
+ std::vector<std::vector<llama_grammar_element>> expected_rules = {
31
+ {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
32
+ {
33
+ {LLAMA_GRETYPE_RULE_REF, 2},
34
+ {LLAMA_GRETYPE_CHAR, 61},
35
+ {LLAMA_GRETYPE_RULE_REF, 3},
36
+ {LLAMA_GRETYPE_RULE_REF, 4},
37
+ {LLAMA_GRETYPE_CHAR, 10},
38
+ {LLAMA_GRETYPE_END, 0},
39
+ },
40
+ {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
41
+ {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
42
+ {
43
+ {LLAMA_GRETYPE_RULE_REF, 8},
44
+ {LLAMA_GRETYPE_ALT, 0},
45
+ {LLAMA_GRETYPE_RULE_REF, 9},
46
+ {LLAMA_GRETYPE_ALT, 0},
47
+ {LLAMA_GRETYPE_CHAR, 40},
48
+ {LLAMA_GRETYPE_RULE_REF, 3},
49
+ {LLAMA_GRETYPE_RULE_REF, 2},
50
+ {LLAMA_GRETYPE_CHAR, 41},
51
+ {LLAMA_GRETYPE_RULE_REF, 3},
52
+ {LLAMA_GRETYPE_END, 0},
53
+ },
54
+ {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
55
+ {
56
+ {LLAMA_GRETYPE_CHAR, 45},
57
+ {LLAMA_GRETYPE_CHAR_ALT, 43},
58
+ {LLAMA_GRETYPE_CHAR_ALT, 42},
59
+ {LLAMA_GRETYPE_CHAR_ALT, 47},
60
+ {LLAMA_GRETYPE_RULE_REF, 4},
61
+ {LLAMA_GRETYPE_END, 0},
62
+ },
63
+ {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
64
+ {
65
+ {LLAMA_GRETYPE_CHAR, 97},
66
+ {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
67
+ {LLAMA_GRETYPE_RULE_REF, 10},
68
+ {LLAMA_GRETYPE_RULE_REF, 3},
69
+ {LLAMA_GRETYPE_END, 0},
70
+ },
71
+ {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
72
+ {
73
+ {LLAMA_GRETYPE_CHAR, 97},
74
+ {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
75
+ {LLAMA_GRETYPE_CHAR_ALT, 48},
76
+ {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
77
+ {LLAMA_GRETYPE_CHAR_ALT, 95},
78
+ {LLAMA_GRETYPE_RULE_REF, 10},
79
+ {LLAMA_GRETYPE_ALT, 0},
80
+ {LLAMA_GRETYPE_END, 0},
81
+ },
82
+ {
83
+ {LLAMA_GRETYPE_CHAR, 48},
84
+ {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
85
+ {LLAMA_GRETYPE_RULE_REF, 11},
86
+ {LLAMA_GRETYPE_ALT, 0},
87
+ {LLAMA_GRETYPE_CHAR, 48},
88
+ {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
89
+ {LLAMA_GRETYPE_END, 0},
90
+ },
91
+ {
92
+ {LLAMA_GRETYPE_CHAR, 32},
93
+ {LLAMA_GRETYPE_CHAR_ALT, 9},
94
+ {LLAMA_GRETYPE_CHAR_ALT, 10},
95
+ {LLAMA_GRETYPE_RULE_REF, 12},
96
+ {LLAMA_GRETYPE_ALT, 0},
97
+ {LLAMA_GRETYPE_END, 0},
98
+ },
99
+ };
100
+
101
+ for (auto pair : expected)
102
+ {
103
+ parsed_grammar.symbol_ids[pair.first] = pair.second;
104
+ }
105
+
106
+ for (auto rule : expected_rules)
107
+ {
108
+ parsed_grammar.rules.emplace_back();
109
+ for (auto element : rule)
110
+ {
111
+ parsed_grammar.rules.back().push_back(element);
112
+ }
113
+ }
114
+
115
+ llama_grammar *grammar = NULL;
116
+ std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
117
+ grammar = llama_grammar_init(
118
+ grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
119
+
120
+ std::vector<std::vector<llama_grammar_element>> expected_stacks = {
121
+ {
122
+ {LLAMA_GRETYPE_RULE_REF, 5},
123
+ {LLAMA_GRETYPE_CHAR, 61},
124
+ {LLAMA_GRETYPE_RULE_REF, 7},
125
+ {LLAMA_GRETYPE_CHAR, 97},
126
+ },
127
+ {
128
+ {LLAMA_GRETYPE_RULE_REF, 5},
129
+ {LLAMA_GRETYPE_CHAR, 61},
130
+ {LLAMA_GRETYPE_RULE_REF, 7},
131
+ {LLAMA_GRETYPE_RULE_REF, 3},
132
+ {LLAMA_GRETYPE_CHAR, 48},
133
+ },
134
+ {
135
+ {LLAMA_GRETYPE_RULE_REF, 5},
136
+ {LLAMA_GRETYPE_CHAR, 61},
137
+ {LLAMA_GRETYPE_RULE_REF, 7},
138
+ {LLAMA_GRETYPE_RULE_REF, 3},
139
+ {LLAMA_GRETYPE_CHAR, 48},
140
+ },
141
+ {
142
+ {LLAMA_GRETYPE_RULE_REF, 5},
143
+ {LLAMA_GRETYPE_CHAR, 61},
144
+ {LLAMA_GRETYPE_RULE_REF, 7},
145
+ {LLAMA_GRETYPE_CHAR, 40},
146
+ },
147
+ {
148
+ {LLAMA_GRETYPE_CHAR, 61},
149
+ {LLAMA_GRETYPE_RULE_REF, 7},
150
+ {LLAMA_GRETYPE_CHAR, 97},
151
+ },
152
+ {
153
+ {LLAMA_GRETYPE_CHAR, 61},
154
+ {LLAMA_GRETYPE_RULE_REF, 7},
155
+ {LLAMA_GRETYPE_RULE_REF, 3},
156
+ {LLAMA_GRETYPE_CHAR, 48},
157
+ },
158
+ {
159
+ {LLAMA_GRETYPE_CHAR, 61},
160
+ {LLAMA_GRETYPE_RULE_REF, 7},
161
+ {LLAMA_GRETYPE_RULE_REF, 3},
162
+ {LLAMA_GRETYPE_CHAR, 48},
163
+ },
164
+ {
165
+ {LLAMA_GRETYPE_CHAR, 61},
166
+ {LLAMA_GRETYPE_RULE_REF, 7},
167
+ {LLAMA_GRETYPE_CHAR, 40},
168
+ }};
169
+
170
+ auto index = 0;
171
+ for (auto stack : grammar->stacks)
172
+ {
173
+ // compare stack to expected_stack
174
+ for (uint32_t i = 0; i < stack.size(); i++)
175
+ {
176
+ auto element = stack[i];
177
+ auto expected_element = expected_stacks[index][i];
178
+
179
+ // pretty print error message before asserting
180
+ if (expected_element.type != element->type || expected_element.value != element->value)
181
+ {
182
+ fprintf(stderr, "index: %d\n", index);
183
+ fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
184
+ fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value);
185
+ fprintf(stderr, "expected_element != actual_element\n");
186
+ }
187
+
188
+ assert(expected_element.type == element->type && expected_element.value == element->value);
189
+ }
190
+ index++;
191
+ }
192
+
193
+ std::vector<llama_grammar_candidate> next_candidates;
194
+ next_candidates.resize(24);
195
+
196
+ for (size_t i = 0; i < 24; ++i)
197
+ {
198
+ uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
199
+ cp[0] = 37 + i;
200
+ cp[1] = 0;
201
+ next_candidates[i] = {i, cp, {}};
202
+ }
203
+
204
+ std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
205
+ {
206
+ {0, 37},
207
+ {1, 38},
208
+ {2, 39},
209
+ {3, 40},
210
+ {4, 41},
211
+ {5, 42},
212
+ {6, 43},
213
+ {7, 44},
214
+ {8, 45},
215
+ {9, 46},
216
+ {10, 47},
217
+ {11, 48},
218
+ {12, 49},
219
+ {13, 50},
220
+ {14, 51},
221
+ {15, 52},
222
+ {16, 53},
223
+ {17, 54},
224
+ {18, 55},
225
+ {19, 56},
226
+ {20, 57},
227
+ {21, 58},
228
+ {22, 59},
229
+ {23, 60},
230
+ },
231
+ {
232
+ {0, 37},
233
+ {1, 38},
234
+ {2, 39},
235
+ {3, 40},
236
+ {4, 41},
237
+ {5, 42},
238
+ {6, 43},
239
+ {7, 44},
240
+ {8, 45},
241
+ {9, 46},
242
+ {10, 47},
243
+ {21, 58},
244
+ {22, 59},
245
+ {23, 60},
246
+ },
247
+ {
248
+ {0, 37},
249
+ {1, 38},
250
+ {2, 39},
251
+ {3, 40},
252
+ {4, 41},
253
+ {5, 42},
254
+ {6, 43},
255
+ {7, 44},
256
+ {8, 45},
257
+ {9, 46},
258
+ {10, 47},
259
+ {21, 58},
260
+ {22, 59},
261
+ {23, 60},
262
+ },
263
+ {
264
+ {0, 37},
265
+ {1, 38},
266
+ {2, 39},
267
+ {4, 41},
268
+ {5, 42},
269
+ {6, 43},
270
+ {7, 44},
271
+ {8, 45},
272
+ {9, 46},
273
+ {10, 47},
274
+ {11, 48},
275
+ {12, 49},
276
+ {13, 50},
277
+ {14, 51},
278
+ {15, 52},
279
+ {16, 53},
280
+ {17, 54},
281
+ {18, 55},
282
+ {19, 56},
283
+ {20, 57},
284
+ {21, 58},
285
+ {22, 59},
286
+ {23, 60},
287
+ },
288
+ {
289
+ {0, 37},
290
+ {1, 38},
291
+ {2, 39},
292
+ {3, 40},
293
+ {4, 41},
294
+ {5, 42},
295
+ {6, 43},
296
+ {7, 44},
297
+ {8, 45},
298
+ {9, 46},
299
+ {10, 47},
300
+ {11, 48},
301
+ {12, 49},
302
+ {13, 50},
303
+ {14, 51},
304
+ {15, 52},
305
+ {16, 53},
306
+ {17, 54},
307
+ {18, 55},
308
+ {19, 56},
309
+ {20, 57},
310
+ {21, 58},
311
+ {22, 59},
312
+ {23, 60},
313
+ },
314
+ {
315
+ {0, 37},
316
+ {1, 38},
317
+ {2, 39},
318
+ {3, 40},
319
+ {4, 41},
320
+ {5, 42},
321
+ {6, 43},
322
+ {7, 44},
323
+ {8, 45},
324
+ {9, 46},
325
+ {10, 47},
326
+ {21, 58},
327
+ {22, 59},
328
+ {23, 60},
329
+ },
330
+ {
331
+ {0, 37},
332
+ {1, 38},
333
+ {2, 39},
334
+ {3, 40},
335
+ {4, 41},
336
+ {5, 42},
337
+ {6, 43},
338
+ {7, 44},
339
+ {8, 45},
340
+ {9, 46},
341
+ {10, 47},
342
+ {21, 58},
343
+ {22, 59},
344
+ {23, 60},
345
+ },
346
+ {
347
+ {0, 37},
348
+ {1, 38},
349
+ {2, 39},
350
+ {4, 41},
351
+ {5, 42},
352
+ {6, 43},
353
+ {7, 44},
354
+ {8, 45},
355
+ {9, 46},
356
+ {10, 47},
357
+ {11, 48},
358
+ {12, 49},
359
+ {13, 50},
360
+ {14, 51},
361
+ {15, 52},
362
+ {16, 53},
363
+ {17, 54},
364
+ {18, 55},
365
+ {19, 56},
366
+ {20, 57},
367
+ {21, 58},
368
+ {22, 59},
369
+ {23, 60},
370
+ },
371
+ };
372
+
373
+ std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
374
+
375
+ std::vector<std::vector<llama_grammar_candidate>> all_rejects;
376
+
377
+ for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
378
+ {
379
+ rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
380
+ all_rejects.push_back(rejects);
381
+ }
382
+
383
+ index = 0;
384
+ for (auto rej : all_rejects)
385
+ {
386
+ for (uint32_t i = 0; i < rej.size(); i++)
387
+ {
388
+ auto element = rej[i];
389
+ auto expected_element = expected_reject[index][i];
390
+ assert(element.index == expected_element.first && *element.code_points == expected_element.second);
391
+ }
392
+ index++;
393
+ }
394
+
395
+ for (auto &candidate : next_candidates)
396
+ {
397
+ delete[] candidate.code_points;
398
+ candidate.code_points = nullptr;
399
+ }
400
+ delete grammar;
401
+ return 0;
402
+ }
@@ -0,0 +1,27 @@
1
+ #include "llama.h"
2
+ #include "get-model.h"
3
+
4
+ #include <cstdlib>
5
+
6
+ int main(int argc, char *argv[] ) {
7
+ auto * model_path = get_model_or_exit(argc, argv);
8
+ auto * file = fopen(model_path, "r");
9
+ if (file == nullptr) {
10
+ fprintf(stderr, "no model at '%s' found\n", model_path);
11
+ return EXIT_FAILURE;
12
+ }
13
+
14
+ fprintf(stderr, "using '%s'\n", model_path);
15
+ fclose(file);
16
+
17
+ llama_backend_init();
18
+ auto params = llama_model_params{};
19
+ params.use_mmap = false;
20
+ params.progress_callback = [](float progress, void * ctx){
21
+ (void) ctx;
22
+ return progress > 0.50;
23
+ };
24
+ auto * model = llama_load_model_from_file(model_path, params);
25
+ llama_backend_free();
26
+ return model == nullptr ? EXIT_SUCCESS : EXIT_FAILURE;
27
+ }
@@ -0,0 +1,181 @@
1
+ #include "ggml.h"
2
+
3
+ #include <cmath>
4
+ #include <cstdio>
5
+ #include <cstdlib>
6
+ #include <cassert>
7
+
8
+ #define MAX_NARGS 2
9
+
10
+ #if defined(__GNUC__)
11
+ #pragma GCC diagnostic ignored "-Wdouble-promotion"
12
+ #endif
13
+
14
+ //
15
+ // logging
16
+ //
17
+ #define GGML_DEBUG 0
18
+ #if (GGML_DEBUG >= 1)
19
+ #define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
20
+ #else
21
+ #define GGML_PRINT_DEBUG(...)
22
+ #endif
23
+
24
+ #if (GGML_DEBUG >= 5)
25
+ #define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
26
+ #else
27
+ #define GGML_PRINT_DEBUG_5(...)
28
+ #endif
29
+
30
+ #if (GGML_DEBUG >= 10)
31
+ #define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
32
+ #else
33
+ #define GGML_PRINT_DEBUG_10(...)
34
+ #endif
35
+
36
+ #define GGML_PRINT(...) printf(__VA_ARGS__)
37
+
38
+
39
// Returns the next rand() sample scaled by RAND_MAX, i.e. a float in [0, 1].
static float frand(void) {
    const float sample = (float) rand();
    const float scale  = (float) RAND_MAX;
    return sample / scale;
}
42
+
43
+ static struct ggml_tensor * get_random_tensor(
44
+ struct ggml_context * ctx0, int ndims, int64_t ne[], float fmin, float fmax
45
+ ) {
46
+ struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
47
+
48
+ switch (ndims) {
49
+ case 1:
50
+ for (int i0 = 0; i0 < ne[0]; i0++) {
51
+ ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
52
+ }
53
+ break;
54
+ case 2:
55
+ for (int i1 = 0; i1 < ne[1]; i1++) {
56
+ for (int i0 = 0; i0 < ne[0]; i0++) {
57
+ ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
58
+ }
59
+ }
60
+ break;
61
+ case 3:
62
+ for (int i2 = 0; i2 < ne[2]; i2++) {
63
+ for (int i1 = 0; i1 < ne[1]; i1++) {
64
+ for (int i0 = 0; i0 < ne[0]; i0++) {
65
+ ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
66
+ }
67
+ }
68
+ }
69
+ break;
70
+ case 4:
71
+ for (int i3 = 0; i3 < ne[3]; i3++) {
72
+ for (int i2 = 0; i2 < ne[2]; i2++) {
73
+ for (int i1 = 0; i1 < ne[1]; i1++) {
74
+ for (int i0 = 0; i0 < ne[0]; i0++) {
75
+ ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
76
+ }
77
+ }
78
+ }
79
+ }
80
+ break;
81
+ default:
82
+ assert(false);
83
+ }
84
+
85
+ return result;
86
+ }
87
+
88
+ int main(void) {
89
+ struct ggml_init_params params = {
90
+ /* .mem_size = */ 1024*1024*1024,
91
+ /* .mem_buffer = */ NULL,
92
+ /* .no_alloc = */ false,
93
+ };
94
+
95
+ struct ggml_context * ctx = ggml_init(params);
96
+
97
+ int64_t ne1[4] = {4, 128, 1, 1};
98
+ int64_t ne2[4] = {4, 256, 1, 1};
99
+ int64_t ne3[4] = {128, 256, 1, 1};
100
+
101
+ struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
102
+ struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);
103
+ ggml_set_param(ctx, a);
104
+ ggml_set_param(ctx, b);
105
+
106
+ struct ggml_tensor * c = get_random_tensor(ctx, 2, ne3, -1, +1);
107
+
108
+ struct ggml_tensor * ab = ggml_mul_mat(ctx, a, b);
109
+ struct ggml_tensor * d = ggml_sub(ctx, c, ab);
110
+ struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d));
111
+
112
+ struct ggml_cgraph * ge = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
113
+ ggml_build_forward_expand(ge, e);
114
+ ggml_graph_reset(ge);
115
+
116
+ ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
117
+
118
+ const float fe = ggml_get_f32_1d(e, 0);
119
+ printf("%s: e = %.4f\n", __func__, fe);
120
+
121
+ struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
122
+
123
+ ggml_opt(ctx, opt_params, e);
124
+
125
+ ggml_graph_reset(ge);
126
+
127
+ ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
128
+
129
+ const float fe_opt = ggml_get_f32_1d(e, 0);
130
+ printf("%s: original e = %.4f\n", __func__, fe);
131
+ printf("%s: optimized e = %.4f\n", __func__, fe_opt);
132
+
133
+ const bool success = (fe_opt <= fe);
134
+ assert(success);
135
+
136
+ ggml_free(ctx);
137
+ return success ? 0 : -1;
138
+ }
139
+ // int64_t ne1[4] = {4, 128, 1, 1};
140
+ // int64_t ne2[4] = {4, 256, 1, 1};
141
+ // int64_t ne3[4] = {128, 256, 1, 1};
142
+ // main: original e = 25890.9375
143
+ // main: optimized e = 10094.7031
144
+
145
+ // int64_t ne1[4] = {8, 128, 1, 1};
146
+ // int64_t ne2[4] = {8, 256, 1, 1};
147
+ // int64_t ne3[4] = {128, 256, 1, 1};
148
+ // main: original e = 39429.5078
149
+ // main: optimized e = 9275.8936
150
+
151
+ // int64_t ne1[4] = {16, 128, 1, 1};
152
+ // int64_t ne2[4] = {16, 256, 1, 1};
153
+ // int64_t ne3[4] = {128, 256, 1, 1};
154
+ // main: original e = 68371.1328
155
+ // main: optimized e = 7854.4502
156
+
157
+
158
+ // int64_t ne1[4] = {32, 128, 1, 1};
159
+ // int64_t ne2[4] = {32, 256, 1, 1};
160
+ // int64_t ne3[4] = {128, 256, 1, 1};
161
+ // main: original e = 126061.1953
162
+ // main: optimized e = 5451.0166
163
+
164
+ // int64_t ne1[4] = {4, 1024, 1, 1};
165
+ // int64_t ne2[4] = {4, 2048, 1, 1};
166
+ // int64_t ne3[4] = {1024, 2048, 1, 1};
167
+ // main: original e = 1620817.8750
168
+ // main: optimized e = 698387.6875
169
+
170
+ // another run on M1
171
+ // int64_t ne1[4] = {4, 1024, 1, 1};
172
+ // int64_t ne2[4] = {4, 2048, 1, 1};
173
+ // int64_t ne3[4] = {1024, 2048, 1, 1};
174
+ // main: original e = 1629595.6250
175
+ // main: optimized e = 698169.1250
176
+
177
+ // int64_t ne1[4] = {32, 1024, 1, 1};
178
+ // int64_t ne2[4] = {32, 2048, 1, 1};
179
+ // int64_t ne3[4] = {1024, 2048, 1, 1};
180
+ // main: original e = 8146770.5000
181
+ // main: optimized e = 651119.1250