@fugood/llama.node 0.0.1-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204) hide show
  1. package/CMakeLists.txt +85 -0
  2. package/README.md +56 -0
  3. package/bin/darwin/arm64/llama-node.node +0 -0
  4. package/bin/darwin/x64/llama-node.node +0 -0
  5. package/bin/linux/arm64/llama-node.node +0 -0
  6. package/bin/linux/x64/llama-node.node +0 -0
  7. package/bin/win32/arm64/llama-node.node +0 -0
  8. package/bin/win32/arm64/node.lib +0 -0
  9. package/bin/win32/x64/llama-node.node +0 -0
  10. package/bin/win32/x64/node.lib +0 -0
  11. package/lib/binding.js +13 -0
  12. package/lib/binding.ts +57 -0
  13. package/lib/index.js +24 -0
  14. package/lib/index.ts +13 -0
  15. package/package.json +65 -0
  16. package/src/addons.cpp +506 -0
  17. package/src/llama.cpp/CMakeLists.txt +1320 -0
  18. package/src/llama.cpp/build.zig +172 -0
  19. package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
  20. package/src/llama.cpp/common/CMakeLists.txt +87 -0
  21. package/src/llama.cpp/common/base64.hpp +392 -0
  22. package/src/llama.cpp/common/common.cpp +2949 -0
  23. package/src/llama.cpp/common/common.h +324 -0
  24. package/src/llama.cpp/common/console.cpp +501 -0
  25. package/src/llama.cpp/common/console.h +19 -0
  26. package/src/llama.cpp/common/grammar-parser.cpp +440 -0
  27. package/src/llama.cpp/common/grammar-parser.h +29 -0
  28. package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
  29. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
  30. package/src/llama.cpp/common/json.hpp +24766 -0
  31. package/src/llama.cpp/common/log.h +724 -0
  32. package/src/llama.cpp/common/ngram-cache.cpp +282 -0
  33. package/src/llama.cpp/common/ngram-cache.h +94 -0
  34. package/src/llama.cpp/common/sampling.cpp +353 -0
  35. package/src/llama.cpp/common/sampling.h +147 -0
  36. package/src/llama.cpp/common/stb_image.h +8396 -0
  37. package/src/llama.cpp/common/train.cpp +1513 -0
  38. package/src/llama.cpp/common/train.h +233 -0
  39. package/src/llama.cpp/examples/CMakeLists.txt +52 -0
  40. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
  41. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
  42. package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
  43. package/src/llama.cpp/examples/batched/batched.cpp +262 -0
  44. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
  45. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
  46. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
  47. package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
  48. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
  49. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
  50. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
  51. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
  52. package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
  54. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
  55. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
  56. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
  58. package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
  59. package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
  60. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
  61. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
  62. package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
  64. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
  65. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
  66. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
  67. package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
  68. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
  69. package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
  70. package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
  71. package/src/llama.cpp/examples/infill/infill.cpp +767 -0
  72. package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
  73. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
  74. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
  75. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
  76. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
  77. package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
  78. package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
  79. package/src/llama.cpp/examples/llava/clip.h +85 -0
  80. package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
  81. package/src/llama.cpp/examples/llava/llava.cpp +426 -0
  82. package/src/llama.cpp/examples/llava/llava.h +50 -0
  83. package/src/llama.cpp/examples/llava/requirements.txt +3 -0
  84. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
  85. package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
  86. package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
  87. package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
  88. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
  89. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
  90. package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
  91. package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
  92. package/src/llama.cpp/examples/main/main.cpp +957 -0
  93. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
  94. package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
  95. package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
  96. package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
  97. package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
  98. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
  99. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
  100. package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
  101. package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
  102. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
  103. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
  104. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
  106. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
  107. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
  108. package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
  109. package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
  110. package/src/llama.cpp/examples/server/httplib.h +9465 -0
  111. package/src/llama.cpp/examples/server/server.cpp +3826 -0
  112. package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
  113. package/src/llama.cpp/examples/server/utils.hpp +653 -0
  114. package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
  115. package/src/llama.cpp/examples/simple/simple.cpp +183 -0
  116. package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
  117. package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
  118. package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
  119. package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
  120. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
  121. package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
  122. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
  123. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
  124. package/src/llama.cpp/ggml-alloc.c +985 -0
  125. package/src/llama.cpp/ggml-alloc.h +76 -0
  126. package/src/llama.cpp/ggml-backend-impl.h +141 -0
  127. package/src/llama.cpp/ggml-backend.c +2099 -0
  128. package/src/llama.cpp/ggml-backend.h +233 -0
  129. package/src/llama.cpp/ggml-common.h +1853 -0
  130. package/src/llama.cpp/ggml-cuda.h +43 -0
  131. package/src/llama.cpp/ggml-impl.h +265 -0
  132. package/src/llama.cpp/ggml-kompute.cpp +2006 -0
  133. package/src/llama.cpp/ggml-kompute.h +46 -0
  134. package/src/llama.cpp/ggml-metal.h +66 -0
  135. package/src/llama.cpp/ggml-mpi.c +216 -0
  136. package/src/llama.cpp/ggml-mpi.h +39 -0
  137. package/src/llama.cpp/ggml-opencl.cpp +2301 -0
  138. package/src/llama.cpp/ggml-opencl.h +36 -0
  139. package/src/llama.cpp/ggml-quants.c +12678 -0
  140. package/src/llama.cpp/ggml-quants.h +133 -0
  141. package/src/llama.cpp/ggml-sycl.cpp +17882 -0
  142. package/src/llama.cpp/ggml-sycl.h +49 -0
  143. package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
  144. package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
  145. package/src/llama.cpp/ggml-vulkan.h +29 -0
  146. package/src/llama.cpp/ggml.c +21819 -0
  147. package/src/llama.cpp/ggml.h +2403 -0
  148. package/src/llama.cpp/llama.cpp +17468 -0
  149. package/src/llama.cpp/llama.h +1117 -0
  150. package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
  151. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
  152. package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
  153. package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
  154. package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
  155. package/src/llama.cpp/prompts/alpaca.txt +1 -0
  156. package/src/llama.cpp/prompts/assistant.txt +31 -0
  157. package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
  158. package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
  159. package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
  160. package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
  161. package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
  162. package/src/llama.cpp/prompts/chat.txt +28 -0
  163. package/src/llama.cpp/prompts/dan-modified.txt +1 -0
  164. package/src/llama.cpp/prompts/dan.txt +1 -0
  165. package/src/llama.cpp/prompts/mnemonics.txt +93 -0
  166. package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
  167. package/src/llama.cpp/prompts/reason-act.txt +18 -0
  168. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
  169. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
  170. package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
  171. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
  172. package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
  173. package/src/llama.cpp/requirements.txt +12 -0
  174. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
  175. package/src/llama.cpp/scripts/xxd.cmake +16 -0
  176. package/src/llama.cpp/sgemm.cpp +999 -0
  177. package/src/llama.cpp/sgemm.h +12 -0
  178. package/src/llama.cpp/tests/CMakeLists.txt +78 -0
  179. package/src/llama.cpp/tests/get-model.cpp +21 -0
  180. package/src/llama.cpp/tests/get-model.h +2 -0
  181. package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
  182. package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
  183. package/src/llama.cpp/tests/test-c.c +7 -0
  184. package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
  185. package/src/llama.cpp/tests/test-double-float.cpp +57 -0
  186. package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
  187. package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
  188. package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
  189. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
  190. package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
  191. package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
  192. package/src/llama.cpp/tests/test-opt.cpp +181 -0
  193. package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
  194. package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
  195. package/src/llama.cpp/tests/test-rope.cpp +221 -0
  196. package/src/llama.cpp/tests/test-sampling.cpp +301 -0
  197. package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
  198. package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
  199. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
  200. package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
  201. package/src/llama.cpp/unicode-data.cpp +1651 -0
  202. package/src/llama.cpp/unicode-data.h +16 -0
  203. package/src/llama.cpp/unicode.cpp +277 -0
  204. package/src/llama.cpp/unicode.h +28 -0
@@ -0,0 +1,221 @@
1
+ #include "ggml.h"
2
+
3
+ #include <cmath>
4
+ #include <cstdio>
5
+ #include <cstdlib>
6
+ #include <cassert>
7
+ #include <vector>
8
+
9
+ #if defined(_MSC_VER)
10
+ #pragma warning(disable: 4244 4267) // possible loss of data
11
+ #endif
12
+
13
+ #if defined(__GNUC__)
14
+ #pragma GCC diagnostic ignored "-Wdouble-promotion"
15
+ #endif
16
+
17
+ #define MAX_NARGS 3
18
+
19
+ #undef MIN
20
+ #undef MAX
21
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
22
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
23
+
24
+ #define GGML_SILU_FP16
25
+
26
+ //
27
+ // logging
28
+ //
29
+
30
+ #if (GGML_DEBUG >= 1)
31
+ #define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
32
+ #else
33
+ #define GGML_PRINT_DEBUG(...)
34
+ #endif
35
+
36
+ #if (GGML_DEBUG >= 5)
37
+ #define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
38
+ #else
39
+ #define GGML_PRINT_DEBUG_5(...)
40
+ #endif
41
+
42
+ #if (GGML_DEBUG >= 10)
43
+ #define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
44
+ #else
45
+ #define GGML_PRINT_DEBUG_10(...)
46
+ #endif
47
+
48
+ #define GGML_PRINT(...) printf(__VA_ARGS__)
49
+
50
// Uniform pseudo-random float in [0.0, 1.0] (uses the unseeded libc PRNG).
static float frand(void) {
    const float r = (float) rand();
    return r / (float) RAND_MAX;
}
53
+
54
// Uniform pseudo-random int in [0, n); defined to be 0 when n == 0
// (avoids the rand() % 0 division by zero).
static int irand(int n) {
    return n == 0 ? 0 : rand() % n;
}
58
+
59
+ static void get_random_dims(int64_t * dims, int ndims) {
60
+ dims[0] = dims[1] = dims[2] = dims[3] = 1;
61
+
62
+ for (int i = 0; i < ndims; i++) {
63
+ dims[i] = 1 + irand(4);
64
+ }
65
+ }
66
+
67
+ static struct ggml_tensor * get_random_tensor_f32(
68
+ struct ggml_context * ctx0,
69
+ int ndims,
70
+ const int64_t ne[],
71
+ float fmin,
72
+ float fmax) {
73
+ struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
74
+
75
+ switch (ndims) {
76
+ case 1:
77
+ for (int i0 = 0; i0 < ne[0]; i0++) {
78
+ ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
79
+ }
80
+ break;
81
+ case 2:
82
+ for (int i1 = 0; i1 < ne[1]; i1++) {
83
+ for (int i0 = 0; i0 < ne[0]; i0++) {
84
+ ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
85
+ }
86
+ }
87
+ break;
88
+ case 3:
89
+ for (int i2 = 0; i2 < ne[2]; i2++) {
90
+ for (int i1 = 0; i1 < ne[1]; i1++) {
91
+ for (int i0 = 0; i0 < ne[0]; i0++) {
92
+ ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
93
+ }
94
+ }
95
+ }
96
+ break;
97
+ case 4:
98
+ for (int i3 = 0; i3 < ne[3]; i3++) {
99
+ for (int i2 = 0; i2 < ne[2]; i2++) {
100
+ for (int i1 = 0; i1 < ne[1]; i1++) {
101
+ for (int i0 = 0; i0 < ne[0]; i0++) {
102
+ ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
103
+ }
104
+ }
105
+ }
106
+ }
107
+ break;
108
+ default:
109
+ assert(false);
110
+ };
111
+
112
+ return result;
113
+ }
114
+
115
+ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
116
+ struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
117
+
118
+ if (plan.work_size > 0) {
119
+ buf.resize(plan.work_size);
120
+ plan.work_data = buf.data();
121
+ }
122
+
123
+ ggml_graph_compute(graph, &plan);
124
+ }
125
+
126
// Regression test for RoPE position handling: applying RoPE at positions
// starting from n_past_0 and then rotating by (n_past_2 - n_past_0) must give
// the same result as applying RoPE directly at positions starting from
// n_past_2. Checked for modes 0, 2 and 4.
int main(int /*argc*/, const char ** /*argv*/) {
    struct ggml_init_params params = {
        /* .mem_size   = */ 128*1024*1024,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ false,
    };

    std::vector<uint8_t> work_buffer;

    struct ggml_context * ctx0 = ggml_init(params);

    struct ggml_tensor * x;

    // rope f32
    for (int m = 0; m < 3; ++m) {
        const int ndims = 4;

        const int64_t n_rot = 128;
        const int64_t ne[4] = { 2*n_rot, 32, 73, 1 };

        const int n_past_0 = 100;
        const int n_past_2 = 33;

        // per-row position tensors (one position per ne[2] row)
        struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
        struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
        struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);

        for (int i = 0; i < ne[2]; ++i) {
            ((int32_t *) p0->data)[i] = n_past_0 + i;
            ((int32_t *) p1->data)[i] = n_past_2 - n_past_0; // constant shift applied on top of p0
            ((int32_t *) p2->data)[i] = n_past_2 + i;
        }

        // test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
        const int mode = m == 0 ? 0 : m == 1 ? 2 : 4;

        x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);

        // 100, 101, 102, ..., 172
        struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode, 1024);
        // -67, -67, -67, ..., -67
        struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode, 1024); // "context swap", i.e. forget n_past_0 - n_past_2 tokens

        // 33, 34, 35, ..., 105
        struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode, 1024);

        ggml_cgraph * gf = ggml_new_graph(ctx0);

        // compute r0, r1 and r2 in a single graph evaluation
        ggml_build_forward_expand(gf, r0);
        ggml_build_forward_expand(gf, r1);
        ggml_build_forward_expand(gf, r2);

        ggml_graph_compute_helper(work_buffer, gf, 4);

        // check that r1 and r2 are the same
        {
            double sum0 = 0.0f;
            double sum1 = 0.0f;
            double diff = 0.0f;

            const float * r1_data = (float *) r1->data;
            const float * r2_data = (float *) r2->data;

            const int n_elements = ggml_nelements(r1);

            // accumulate L1 norms of both results and of their difference
            for (int i = 0; i < n_elements; ++i) {
                sum0 += fabs(r1_data[i]);
                sum1 += fabs(r2_data[i]);
                diff += fabs(r1_data[i] - r2_data[i]);
                //if (fabs(r1_data[i] - r2_data[i]) > 0.0001f) {
                //    printf("%d: %f %f\n", i, r1_data[i], r2_data[i]);
                //    printf("diff: %f\n", fabs(r1_data[i] - r2_data[i]));
                //}
            }

            //for (int i = 4096; i < 4096 + 128; ++i) {
            //    printf("%f %f\n", r1_data[i], r2_data[i]);
            //}

            printf("mode: %d\n", mode);
            printf("sum0: %f\n", sum0);
            printf("sum1: %f\n", sum1);
            printf("diff: %f\n", diff);
            printf("rel err: %f\n", diff / sum0);
            printf("rel err: %f\n", diff / sum1);

            // relative L1 error must be below 1e-4 w.r.t. either result
            GGML_ASSERT(diff / sum0 < 0.0001f);
            GGML_ASSERT(diff / sum1 < 0.0001f);
        }
    }

    ggml_free(ctx0);

    return 0;
}
221
+
@@ -0,0 +1,301 @@
1
+ #include "ggml.h"
2
+ #include "llama.h"
3
+
4
+ #ifdef NDEBUG
5
+ #undef NDEBUG
6
+ #endif
7
+
8
+ #include <algorithm>
9
+ #include <cmath>
10
+ #include <string>
11
+ #include <vector>
12
+
13
+ static void dump(const llama_token_data_array * candidates) {
14
+ for (size_t i = 0; i < candidates->size; i++) {
15
+ printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
16
+ }
17
+ }
18
+
19
+ #define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
20
+
21
+ static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
22
+ const size_t n_vocab = probs.size();
23
+ std::vector<llama_token_data> candidates;
24
+ candidates.reserve(n_vocab);
25
+ for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
26
+ const float logit = logf(probs[token_id]);
27
+ candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
28
+ }
29
+
30
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
31
+ llama_sample_softmax(nullptr, &candidates_p);
32
+ DUMP(&candidates_p);
33
+ llama_sample_top_k(nullptr, &candidates_p, k, 1);
34
+ DUMP(&candidates_p);
35
+
36
+ GGML_ASSERT(candidates_p.size == expected_probs.size());
37
+ for (size_t i = 0; i < candidates_p.size; i++) {
38
+ GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
39
+ }
40
+ }
41
+
42
+ static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
43
+ const size_t n_vocab = probs.size();
44
+ std::vector<llama_token_data> candidates;
45
+ candidates.reserve(n_vocab);
46
+ for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
47
+ const float logit = logf(probs[token_id]);
48
+ candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
49
+ }
50
+
51
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
52
+ llama_sample_softmax(nullptr, &candidates_p);
53
+ DUMP(&candidates_p);
54
+ llama_sample_top_p(nullptr, &candidates_p, p, 1);
55
+ DUMP(&candidates_p);
56
+
57
+ GGML_ASSERT(candidates_p.size == expected_probs.size());
58
+ for (size_t i = 0; i < candidates_p.size; i++) {
59
+ GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
60
+ }
61
+ }
62
+
63
+ static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
64
+ const size_t n_vocab = probs.size();
65
+ std::vector<llama_token_data> candidates;
66
+ candidates.reserve(n_vocab);
67
+ for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
68
+ const float logit = logf(probs[token_id]);
69
+ candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
70
+ }
71
+
72
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
73
+ DUMP(&candidates_p);
74
+ llama_sample_tail_free(nullptr, &candidates_p, z, 1);
75
+ DUMP(&candidates_p);
76
+
77
+ GGML_ASSERT(candidates_p.size == expected_probs.size());
78
+ for (size_t i = 0; i < candidates_p.size; i++) {
79
+ GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
80
+ }
81
+ }
82
+
83
+ static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
84
+ const size_t n_vocab = probs.size();
85
+ std::vector<llama_token_data> candidates;
86
+ candidates.reserve(n_vocab);
87
+ for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
88
+ const float logit = logf(probs[token_id]);
89
+ candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
90
+ }
91
+
92
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
93
+ DUMP(&candidates_p);
94
+ llama_sample_min_p(nullptr, &candidates_p, p, 1);
95
+ DUMP(&candidates_p);
96
+ llama_sample_softmax(nullptr, &candidates_p);
97
+
98
+ GGML_ASSERT(candidates_p.size == expected_probs.size());
99
+ for (size_t i = 0; i < candidates_p.size; i++) {
100
+ GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
101
+ }
102
+ }
103
+
104
+ static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
105
+ const size_t n_vocab = probs.size();
106
+ std::vector<llama_token_data> candidates;
107
+ candidates.reserve(n_vocab);
108
+ for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
109
+ const float logit = logf(probs[token_id]);
110
+ candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
111
+ }
112
+
113
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
114
+ DUMP(&candidates_p);
115
+ llama_sample_typical(nullptr, &candidates_p, p, 1);
116
+ DUMP(&candidates_p);
117
+
118
+ GGML_ASSERT(candidates_p.size == expected_probs.size());
119
+ for (size_t i = 0; i < candidates_p.size; i++) {
120
+ GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
121
+ }
122
+ }
123
+
124
+ static void test_repetition_penalties(
125
+ const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
126
+ const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
127
+ ) {
128
+ GGML_ASSERT(probs.size() == expected_probs.size());
129
+
130
+ const size_t n_vocab = probs.size();
131
+ std::vector<llama_token_data> candidates;
132
+ candidates.reserve(n_vocab);
133
+ for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
134
+ const float logit = logf(probs[token_id]);
135
+ candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
136
+ }
137
+
138
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
139
+ llama_sample_softmax(nullptr, &candidates_p);
140
+ DUMP(&candidates_p);
141
+ llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
142
+ llama_sample_softmax(nullptr, &candidates_p);
143
+ DUMP(&candidates_p);
144
+
145
+ GGML_ASSERT(candidates_p.size == expected_probs.size());
146
+ for (size_t i = 0; i < candidates_p.size; i++) {
147
+ GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
148
+ }
149
+ }
150
+
151
+ static void test_sampler_queue(
152
+ const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
153
+ ) {
154
+ std::vector<llama_token_data> candidates;
155
+ candidates.reserve(n_vocab);
156
+ for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
157
+ const float logit = logf(token_id);
158
+ candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
159
+ }
160
+
161
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
162
+
163
+ llama_token min_token_id = 0;
164
+ const llama_token max_token_id = n_vocab-1;
165
+
166
+ for (auto s : samplers_sequence) {
167
+ switch (s){
168
+ case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
169
+ case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
170
+ case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
171
+ case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
172
+ case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
173
+ case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
174
+ default : GGML_ASSERT(false && "Unknown sampler"); break;
175
+ }
176
+
177
+ llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
178
+
179
+ const int size = candidates_p.size;
180
+
181
+ if (s == 'k') {
182
+ const int expected_size = std::min(size, top_k);
183
+ min_token_id = std::max(min_token_id, (llama_token)(n_vocab - top_k));
184
+
185
+ GGML_ASSERT(size == expected_size);
186
+ GGML_ASSERT(candidates_p.data[0].id == max_token_id);
187
+ GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
188
+ } else if (s == 'p') {
189
+ const int softmax_divisor = n_vocab * (n_vocab-1) / 2 - min_token_id * (min_token_id-1) / 2;
190
+ const int softmax_numerator_target = ceilf(top_p * softmax_divisor);
191
+
192
+ min_token_id = n_vocab;
193
+ int expected_size = 0;
194
+ int cumsum = 0;
195
+ do { // do-while because always at least one token is sampled
196
+ min_token_id--;
197
+ expected_size++;
198
+
199
+ cumsum += min_token_id;
200
+ } while (cumsum < softmax_numerator_target);
201
+
202
+ // token 0 has p == 0, need special consideration for cumsum because top_p immediately returns
203
+ if (min_token_id == 1) {
204
+ min_token_id--;
205
+ expected_size += 1;
206
+ }
207
+
208
+ GGML_ASSERT(size == expected_size);
209
+ GGML_ASSERT(candidates_p.data[0].id == max_token_id);
210
+ GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
211
+ } else if (s == 'm') {
212
+ int expected_size = ceilf((1.0f-min_p) * n_vocab);
213
+ expected_size = std::max(expected_size, 1);
214
+ expected_size = std::min(expected_size, size);
215
+
216
+ min_token_id = floorf(min_p * n_vocab);
217
+ min_token_id = std::max(min_token_id, 1);
218
+ min_token_id = std::max(min_token_id, (llama_token)(n_vocab - size));
219
+ min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
220
+
221
+ GGML_ASSERT(size == expected_size);
222
+ GGML_ASSERT(candidates_p.data[0].id == max_token_id);
223
+ GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
224
+ } else {
225
+ GGML_ASSERT(false);
226
+ }
227
+ }
228
+
229
+ printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n",
230
+ samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
231
+ }
232
+
233
// Regression tests for the llama.cpp sampling functions. Each call passes a
// small hand-computed distribution plus the expected post-sampler result;
// the test_* helpers assert the match.
int main(void) {
    ggml_time_init();

    // top-k: keep the k highest-probability tokens (k == 0 keeps everything)
    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);

    // top-p: keep the smallest prefix whose cumulative probability reaches p
    // (p == 0 still keeps at least one token)
    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);

    // min-p: expected values are renormalized over the surviving mass
    // (hence the /1.0f, /0.9f, /0.7f, /0.4f divisors)
    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.26f);
    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.49f);
    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.51f);
    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f);
    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
    test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);

    // tail-free sampling
    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);

    // locally-typical sampling
    test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
    test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);

    // repeat penalty only (second arg lists the previously seen tokens)
    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);

    // frequency/presence penalties (repeat penalty neutral at 1.0)
    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);

    // single-sampler queues with neutral parameters (nothing filtered)
    test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
    test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
    test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
    test_sampler_queue(10000, "p", 10000, 0.0f, 1.0f);
    test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f);
    test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12);

    // single-sampler queues with filtering parameters
    test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f);
    test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f);
    test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f);
    test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f);
    test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f);

    // two-sampler queues: every ordering of k/p/m pairs
    test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f);
    test_sampler_queue(10000, "km", 100, 0.8f, 0.1f);
    test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f);
    test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f);
    test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f);
    test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f);
    test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f);

    // three-sampler queues: every permutation of k, p and m
    test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f);
    test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f);
    test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f);
    test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f);
    test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f);
    test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f);

    printf("OK\n");

    return 0;
}