@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190) hide show
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  // Unit tests for quantization specific functions - quantize, dequantize and dot product
2
2
 
3
3
  #include "ggml.h"
4
+ #include "ggml-cpu.h"
4
5
 
5
6
  #undef NDEBUG
6
7
  #include <assert.h>
@@ -44,26 +45,27 @@ static float array_rmse(const float * a1, const float * a2, size_t n) {
44
45
  }
45
46
 
46
47
  // Total quantization error on test data
47
- static float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
48
+ static float total_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
48
49
  std::vector<uint8_t> tmp_q(2*test_size);
49
50
  std::vector<float> tmp_out(test_size);
50
51
 
51
- qfns.from_float(test_data, tmp_q.data(), test_size);
52
- qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
52
+ qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
53
+ qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
53
54
  return array_rmse(test_data, tmp_out.data(), test_size);
54
55
  }
55
56
 
56
57
  // Total quantization error on test data
57
- static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
58
+ static float reference_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
58
59
  std::vector<uint8_t> tmp_q(2*test_size);
59
60
  std::vector<float> tmp_out(test_size);
60
61
  std::vector<float> tmp_out_ref(test_size);
61
62
 
62
- qfns.from_float(test_data, tmp_q.data(), test_size);
63
- qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
63
+ // FIXME: why is done twice?
64
+ qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
65
+ qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
64
66
 
65
- qfns.from_float_ref(test_data, tmp_q.data(), test_size);
66
- qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
67
+ qfns->from_float_ref(test_data, tmp_q.data(), test_size);
68
+ qfns->to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
67
69
 
68
70
  return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
69
71
  }
@@ -78,18 +80,18 @@ static float dot_product(const float * a1, const float * a2, size_t test_size) {
78
80
 
79
81
  // Total dot product error
80
82
  static float dot_product_error(
81
- ggml_type_traits_t & qfns, size_t test_size, const float * test_data1, const float *test_data2
83
+ const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float *test_data2
82
84
  ) {
83
85
  std::vector<uint8_t> tmp_q1(2*test_size);
84
86
  std::vector<uint8_t> tmp_q2(2*test_size);
85
87
 
86
- auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
88
+ const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
87
89
 
88
- qfns.from_float(test_data1, tmp_q1.data(), test_size);
89
- vdot.from_float(test_data2, tmp_q2.data(), test_size);
90
+ qfns_cpu->from_float(test_data1, tmp_q1.data(), test_size);
91
+ vdot->from_float(test_data2, tmp_q2.data(), test_size);
90
92
 
91
93
  float result = INFINITY;
92
- qfns.vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
94
+ qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
93
95
 
94
96
  const float dot_ref = dot_product(test_data1, test_data2, test_size);
95
97
 
@@ -131,10 +133,11 @@ int main(int argc, char * argv[]) {
131
133
 
132
134
  for (int i = 0; i < GGML_TYPE_COUNT; i++) {
133
135
  ggml_type type = (ggml_type) i;
134
- ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
136
+ const auto * qfns = ggml_get_type_traits(type);
137
+ const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
135
138
 
136
139
  // deprecated - skip
137
- if (qfns.blck_size == 0) {
140
+ if (qfns->blck_size == 0) {
138
141
  continue;
139
142
  }
140
143
 
@@ -143,8 +146,8 @@ int main(int argc, char * argv[]) {
143
146
  printf("Testing %s\n", ggml_type_name((ggml_type) i));
144
147
  ggml_quantize_init(ei);
145
148
 
146
- if (qfns.from_float && qfns.to_float) {
147
- const float total_error = total_quantization_error(qfns, test_size, test_data.data());
149
+ if (qfns_cpu->from_float && qfns->to_float) {
150
+ const float total_error = total_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
148
151
  const float max_quantization_error =
149
152
  type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
150
153
  type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
@@ -159,14 +162,14 @@ int main(int argc, char * argv[]) {
159
162
  printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
160
163
  }
161
164
 
162
- const float reference_error = reference_quantization_error(qfns, test_size, test_data.data());
165
+ const float reference_error = reference_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
163
166
  failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
164
167
  num_failed += failed;
165
168
  if (failed || verbose) {
166
169
  printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
167
170
  }
168
171
 
169
- const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data());
172
+ const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());
170
173
  const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
171
174
  type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
172
175
  ? MAX_DOT_PRODUCT_ERROR_LOWBIT
@@ -1,12 +1,12 @@
1
1
  // Benchmark quantization specific functions on synthetic data
2
2
 
3
3
  #include "ggml.h"
4
+ #include "ggml-cpu.h"
4
5
 
5
6
  #undef NDEBUG
6
7
  #include <algorithm>
7
8
  #include <assert.h>
8
9
  #include <functional>
9
- #include <inttypes.h>
10
10
  #include <math.h>
11
11
  #include <memory>
12
12
  #include <stdio.h>
@@ -122,9 +122,10 @@ static void usage(char * argv[]) {
122
122
  printf(" --type TYPE set test type as");
123
123
  for (int i = 0; i < GGML_TYPE_COUNT; i++) {
124
124
  ggml_type type = (ggml_type) i;
125
- ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
125
+ const auto * qfns = ggml_get_type_traits(type);
126
+ const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
126
127
  if (ggml_type_name(type) != NULL) {
127
- if (qfns.from_float && qfns.to_float) {
128
+ if (qfns_cpu->from_float && qfns->to_float) {
128
129
  printf(" %s", ggml_type_name(type));
129
130
  }
130
131
  }
@@ -270,12 +271,13 @@ int main(int argc, char * argv[]) {
270
271
 
271
272
  for (int i = 0; i < GGML_TYPE_COUNT; i++) {
272
273
  ggml_type type = (ggml_type) i;
273
- ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
274
+ const auto * qfns = ggml_get_type_traits(type);
275
+ const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
274
276
  if (!params.include_types.empty() && ggml_type_name(type) && std::find(params.include_types.begin(), params.include_types.end(), ggml_type_name(type)) == params.include_types.end()) {
275
277
  continue;
276
278
  }
277
279
 
278
- if (qfns.from_float && qfns.to_float) {
280
+ if (qfns_cpu->from_float && qfns->to_float) {
279
281
  printf("%s\n", ggml_type_name(type));
280
282
 
281
283
  ggml_quantize_init(type);
@@ -285,7 +287,7 @@ int main(int argc, char * argv[]) {
285
287
  for (size_t size : params.test_sizes) {
286
288
  printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
287
289
  auto quantize_fn = [&](void) -> float {
288
- qfns.from_float_ref(test_data1, test_q1, size);
290
+ qfns->from_float_ref(test_data1, test_q1, size);
289
291
  return test_q1[0];
290
292
  };
291
293
  size_t quantized_size = ggml_row_size(type, size);
@@ -299,7 +301,7 @@ int main(int argc, char * argv[]) {
299
301
  for (size_t size : params.test_sizes) {
300
302
  printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
301
303
  auto quantize_fn = [&](void) -> float {
302
- qfns.from_float(test_data1, test_q1, size);
304
+ qfns_cpu->from_float(test_data1, test_q1, size);
303
305
  return test_q1[0];
304
306
  };
305
307
  size_t quantized_size = ggml_row_size(type, size);
@@ -310,11 +312,11 @@ int main(int argc, char * argv[]) {
310
312
 
311
313
  if (params.op_dequantize_row_q) {
312
314
  printf(" dequantize_row_q\n");
313
- qfns.from_float(test_data1, test_q1, largest);
315
+ qfns_cpu->from_float(test_data1, test_q1, largest);
314
316
  for (size_t size : params.test_sizes) {
315
317
  printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
316
318
  auto quantize_fn = [&](void) -> float {
317
- qfns.to_float(test_q1, test_out, size);
319
+ qfns->to_float(test_q1, test_out, size);
318
320
  return test_out[0];
319
321
  };
320
322
  size_t quantized_size = ggml_row_size(type, size);
@@ -328,8 +330,8 @@ int main(int argc, char * argv[]) {
328
330
  for (size_t size : params.test_sizes) {
329
331
  printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
330
332
  auto quantize_fn = [&](void) -> float {
331
- auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
332
- vdot.from_float(test_data1, test_q1, size);
333
+ const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
334
+ vdot->from_float(test_data1, test_q1, size);
333
335
  return test_q1[0];
334
336
  };
335
337
  size_t quantized_size = ggml_row_size(type, size);
@@ -340,13 +342,13 @@ int main(int argc, char * argv[]) {
340
342
 
341
343
  if (params.op_vec_dot_q) {
342
344
  printf(" vec_dot_q\n");
343
- qfns.from_float(test_data1, test_q1, largest);
344
- qfns.from_float(test_data2, test_q2, largest);
345
+ qfns_cpu->from_float(test_data1, test_q1, largest);
346
+ qfns_cpu->from_float(test_data2, test_q2, largest);
345
347
  for (size_t size : params.test_sizes) {
346
348
  printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
347
349
  auto quantize_fn = [&](void) -> float {
348
350
  float result;
349
- qfns.vec_dot(size, &result, 0, test_q1, 0, test_q2, 0, 1);
351
+ qfns_cpu->vec_dot(size, &result, 0, test_q1, 0, test_q2, 0, 1);
350
352
  return result;
351
353
  };
352
354
  size_t quantized_size = ggml_row_size(type, size);
@@ -1,4 +1,5 @@
1
1
  #include "ggml.h"
2
+ #include "ggml-cpu.h"
2
3
 
3
4
  #include <cmath>
4
5
  #include <cstdio>
@@ -10,6 +10,8 @@
10
10
  #include <string>
11
11
  #include <vector>
12
12
 
13
+ extern struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers);
14
+
13
15
  static void dump(const llama_token_data_array * cur_p) {
14
16
  for (size_t i = 0; i < cur_p->size; i++) {
15
17
  printf("%d: %f (%f)\n", cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
@@ -18,181 +20,188 @@ static void dump(const llama_token_data_array * cur_p) {
18
20
 
19
21
  #define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
20
22
 
21
- #define APPLY(__cnstr, __cur_p) do { \
22
- auto * cnstr = (__cnstr); \
23
- llama_sampler_apply(cnstr, (__cur_p)); \
24
- llama_sampler_free(cnstr); \
25
- } while(0)
23
+ struct sampler_tester {
24
+ sampler_tester(size_t n_vocab) {
25
+ cur.reserve(n_vocab);
26
+ for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
27
+ const float logit = logf(token_id);
28
+ cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
29
+ }
26
30
 
27
- static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
28
- const size_t n_vocab = probs.size();
31
+ cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
32
+ }
29
33
 
30
- std::vector<llama_token_data> cur;
31
- cur.reserve(n_vocab);
32
- for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
33
- const float logit = logf(probs[token_id]);
34
- cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
34
+ sampler_tester(const std::vector<float> & probs, const std::vector<float> & probs_expected) : probs_expected(probs_expected) {
35
+ cur.reserve(probs.size());
36
+ for (llama_token token_id = 0; token_id < (llama_token)probs.size(); token_id++) {
37
+ const float logit = logf(probs[token_id]);
38
+ cur.emplace_back(llama_token_data{token_id, logit, probs[token_id]});
39
+ }
40
+
41
+ cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
35
42
  }
36
43
 
37
- llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
38
- APPLY(llama_sampler_init_softmax(), &cur_p);
39
- DUMP(&cur_p);
40
- APPLY(llama_sampler_init_top_k(k), &cur_p);
41
- DUMP(&cur_p);
42
-
43
- GGML_ASSERT(cur_p.size == expected_probs.size());
44
- for (size_t i = 0; i < cur_p.size; i++) {
45
- GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
44
+ void apply(llama_sampler * sampler) {
45
+ llama_sampler_apply(sampler, &cur_p);
46
+ llama_sampler_free(sampler);
46
47
  }
47
- }
48
48
 
49
- static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
50
- const size_t n_vocab = probs.size();
49
+ void check() {
50
+ GGML_ASSERT(cur_p.size == probs_expected.size());
51
+ for (size_t i = 0; i < cur_p.size; i++) {
52
+ GGML_ASSERT(fabs(cur_p.data[i].p - probs_expected[i]) < 1e-5);
53
+ }
54
+ }
55
+
56
+ llama_token_data_array cur_p;
57
+
58
+ private:
59
+ const std::vector<float> probs_expected;
51
60
 
52
61
  std::vector<llama_token_data> cur;
53
- cur.reserve(n_vocab);
54
- for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
55
- const float logit = logf(probs[token_id]);
56
- cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
57
- }
62
+ };
58
63
 
59
- llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
60
- APPLY(llama_sampler_init_softmax(), &cur_p);
61
- DUMP(&cur_p);
62
- APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
63
- DUMP(&cur_p);
64
-
65
- GGML_ASSERT(cur_p.size == expected_probs.size());
66
- for (size_t i = 0; i < cur_p.size; i++) {
67
- GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
68
- }
64
+ static void test_temp(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp) {
65
+ sampler_tester tester(probs, probs_expected);
66
+
67
+ DUMP(&tester.cur_p);
68
+ tester.apply(llama_sampler_init_temp(temp));
69
+ tester.apply(llama_sampler_init_dist(0));
70
+ DUMP(&tester.cur_p);
71
+
72
+ tester.check();
69
73
  }
70
74
 
71
- static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
72
- const size_t n_vocab = probs.size();
75
+ static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
76
+ sampler_tester tester(probs, probs_expected);
73
77
 
74
- std::vector<llama_token_data> cur;
75
- cur.reserve(n_vocab);
76
- for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
77
- const float logit = logf(probs[token_id]);
78
- cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
79
- }
78
+ DUMP(&tester.cur_p);
79
+ tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent));
80
+ tester.apply(llama_sampler_init_dist (0));
81
+ DUMP(&tester.cur_p);
80
82
 
81
- llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
82
- DUMP(&cur_p);
83
- APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
84
- DUMP(&cur_p);
83
+ tester.check();
84
+ }
85
85
 
86
- GGML_ASSERT(cur_p.size == expected_probs.size());
87
- for (size_t i = 0; i < cur_p.size; i++) {
88
- GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
89
- }
86
+ static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
87
+ sampler_tester tester(probs, probs_expected);
88
+
89
+ DUMP(&tester.cur_p);
90
+ tester.apply(llama_sampler_init_top_k(k));
91
+ tester.apply(llama_sampler_init_dist (0));
92
+ DUMP(&tester.cur_p);
93
+
94
+ tester.check();
90
95
  }
91
96
 
92
- static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
93
- const size_t n_vocab = probs.size();
97
+ static void test_top_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
98
+ sampler_tester tester(probs, probs_expected);
94
99
 
95
- std::vector<llama_token_data> cur;
96
- cur.reserve(n_vocab);
97
- for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
98
- const float logit = logf(probs[token_id]);
99
- cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
100
- }
100
+ DUMP(&tester.cur_p);
101
+ tester.apply(llama_sampler_init_top_p(p, 1));
102
+ tester.apply(llama_sampler_init_dist (0));
103
+ DUMP(&tester.cur_p);
101
104
 
102
- llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
103
- DUMP(&cur_p);
104
- APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
105
- DUMP(&cur_p);
106
- APPLY(llama_sampler_init_softmax(), &cur_p);
107
-
108
- GGML_ASSERT(cur_p.size == expected_probs.size());
109
- for (size_t i = 0; i < cur_p.size; i++) {
110
- GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
111
- }
105
+ tester.check();
112
106
  }
113
107
 
114
- static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
115
- const size_t n_vocab = probs.size();
108
+ static void test_min_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
109
+ sampler_tester tester(probs, probs_expected);
116
110
 
117
- std::vector<llama_token_data> cur;
118
- cur.reserve(n_vocab);
119
- for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
120
- const float logit = logf(probs[token_id]);
121
- cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
122
- }
111
+ DUMP(&tester.cur_p);
112
+ tester.apply(llama_sampler_init_min_p(p, 1));
113
+ tester.apply(llama_sampler_init_dist (0));
114
+ DUMP(&tester.cur_p);
123
115
 
124
- llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
125
- DUMP(&cur_p);
126
- APPLY(llama_sampler_init_typical(p, 1), &cur_p);
127
- DUMP(&cur_p);
116
+ tester.check();
117
+ }
128
118
 
129
- GGML_ASSERT(cur_p.size == expected_probs.size());
130
- for (size_t i = 0; i < cur_p.size; i++) {
131
- GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
132
- }
119
+ static void test_xtc(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p, float t) {
120
+ sampler_tester tester(probs, probs_expected);
121
+
122
+ DUMP(&tester.cur_p);
123
+ tester.apply(llama_sampler_init_xtc(p, t, 0, 0));
124
+ DUMP(&tester.cur_p);
125
+
126
+ tester.check();
127
+ }
128
+
129
+ static void test_typical(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
130
+ sampler_tester tester(probs, probs_expected);
131
+
132
+ DUMP(&tester.cur_p);
133
+ tester.apply(llama_sampler_init_typical(p, 1));
134
+ DUMP(&tester.cur_p);
135
+
136
+ tester.check();
133
137
  }
134
138
 
135
139
  static void test_penalties(
136
140
  const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
137
- const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
141
+ const std::vector<float> & probs_expected, float repeat_penalty, float alpha_frequency, float alpha_presence
138
142
  ) {
139
- GGML_ASSERT(probs.size() == expected_probs.size());
143
+ GGML_ASSERT(probs.size() == probs_expected.size());
144
+
145
+ sampler_tester tester(probs, probs_expected);
140
146
 
141
147
  const size_t n_vocab = probs.size();
148
+ auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
142
149
 
143
- std::vector<llama_token_data> cur;
144
- cur.reserve(n_vocab);
145
- for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
146
- const float logit = logf(probs[token_id]);
147
- cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
150
+ for (size_t i = 0; i < last_tokens.size(); i++) {
151
+ llama_sampler_accept(sampler, last_tokens[i]);
148
152
  }
149
153
 
150
- llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
154
+ DUMP(&tester.cur_p);
155
+ tester.apply(sampler);
156
+ tester.apply(llama_sampler_init_dist(0));
157
+ DUMP(&tester.cur_p);
151
158
 
152
- auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
159
+ tester.check();
160
+ }
161
+
162
+ static void test_dry(
163
+ const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
164
+ const std::vector<float> & expected_probs, float dry_multiplier, float dry_base,
165
+ int dry_allowed_length, int dry_penalty_last_n,
166
+ const std::vector<std::vector<llama_token>> & seq_breakers
167
+ ) {
168
+ GGML_ASSERT(probs.size() == expected_probs.size());
169
+
170
+ sampler_tester tester(probs, expected_probs);
171
+
172
+ auto * sampler = llama_sampler_init_dry_testing(1024, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers);
153
173
 
154
174
  for (size_t i = 0; i < last_tokens.size(); i++) {
155
175
  llama_sampler_accept(sampler, last_tokens[i]);
156
176
  }
157
177
 
158
- APPLY(llama_sampler_init_softmax(), &cur_p);
159
- DUMP(&cur_p);
160
- APPLY(sampler, &cur_p);
161
- APPLY(llama_sampler_init_softmax(), &cur_p);
162
- DUMP(&cur_p);
163
-
164
- GGML_ASSERT(cur_p.size == expected_probs.size());
165
- for (size_t i = 0; i < cur_p.size; i++) {
166
- GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
167
- }
178
+ DUMP(&tester.cur_p);
179
+ tester.apply(sampler);
180
+ tester.apply(llama_sampler_init_dist(0));
181
+ DUMP(&tester.cur_p);
182
+ tester.check();
168
183
  }
169
184
 
170
185
  static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
171
186
  ) {
172
- std::vector<llama_token_data> cur;
173
- cur.reserve(n_vocab);
174
- for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
175
- const float logit = logf(token_id);
176
- cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
177
- }
178
-
179
- llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
187
+ sampler_tester tester(n_vocab);
180
188
 
181
189
  llama_token min_token_id = 0;
182
190
  const llama_token max_token_id = n_vocab-1;
183
191
 
184
192
  for (auto s : samplers_sequence) {
185
193
  switch (s){
186
- case 'k': APPLY(llama_sampler_init_top_k(top_k), &cur_p); break;
187
- case 'f': GGML_ABORT("tail_free test not implemented");
194
+ case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
188
195
  case 'y': GGML_ABORT("typical test not implemented");
189
- case 'p': APPLY(llama_sampler_init_top_p(top_p, 1), &cur_p); break;
190
- case 'm': APPLY(llama_sampler_init_min_p(min_p, 1), &cur_p); break;
196
+ case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
197
+ case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
191
198
  case 't': GGML_ABORT("temperature test not implemented");
192
199
  default : GGML_ABORT("Unknown sampler");
193
200
  }
194
201
 
195
- APPLY(llama_sampler_init_softmax(), &cur_p); // make sure tokens are sorted for tests
202
+ tester.apply(llama_sampler_init_dist(0));
203
+
204
+ auto & cur_p = tester.cur_p;
196
205
 
197
206
  const int size = cur_p.size;
198
207
 
@@ -263,7 +272,7 @@ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vec
263
272
  }
264
273
  const int64_t t_end = ggml_time_us();
265
274
  llama_sampler_free(cnstr);
266
- printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
275
+ printf("%-43s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
267
276
  }
268
277
 
269
278
  #define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
@@ -279,26 +288,31 @@ static void test_perf() {
279
288
  data.emplace_back(llama_token_data{i, logit, 0.0f});
280
289
  }
281
290
 
282
- BENCH(llama_sampler_init_top_k (40), data, 32);
283
- BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
284
- BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
285
- BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
286
- BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
287
- BENCH(llama_sampler_init_softmax (), data, 32);
291
+ BENCH(llama_sampler_init_top_k (40), data, 32);
292
+ BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
293
+ BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
294
+ BENCH(llama_sampler_init_typical(0.5f, 1), data, 32);
295
+ BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
288
296
  }
289
297
 
290
298
  int main(void) {
291
299
  ggml_time_init();
292
300
 
293
- test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
294
- test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
301
+ test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
302
+ test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
303
+
304
+ test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
305
+ test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
306
+
307
+ test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
308
+ test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
295
309
  test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
296
310
  test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
297
311
 
298
- test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
299
- test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
300
- test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
301
- test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
312
+ test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
313
+ test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
314
+ test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
315
+ test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
302
316
 
303
317
  test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
304
318
  test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
@@ -309,9 +323,13 @@ int main(void) {
309
323
  test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
310
324
  test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
311
325
 
312
- test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
313
- test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
314
- test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
326
+ printf("XTC should:\n");
327
+ test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.09f);
328
+ test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.19f);
329
+ test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.29f);
330
+
331
+ printf("XTC should not:\n");
332
+ test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f);
315
333
 
316
334
  test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
317
335
  test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
@@ -324,6 +342,13 @@ int main(void) {
324
342
  test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
325
343
  test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
326
344
 
345
+
346
+ test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1}, {0.25f, 0.25f, 0.25f, 0.25f}, 1.0f, 1.1f, 2, 4, {});
347
+ test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.296923f, 0.109232f}, 1.0f, 1.1f, 2, 5, {});
348
+ test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 2, 6, {{3}});
349
+ test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
350
+ test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
351
+
327
352
  test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
328
353
  test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
329
354
  test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);