@fugood/llama.node 0.3.2 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +2 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +8 -9
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +43 -9
- package/src/llama.cpp/.github/workflows/docker.yml +3 -0
- package/src/llama.cpp/CMakeLists.txt +7 -4
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +0 -2
- package/src/llama.cpp/common/arg.cpp +642 -607
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +79 -281
- package/src/llama.cpp/common/common.h +130 -100
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +116 -108
- package/src/llama.cpp/common/sampling.h +20 -20
- package/src/llama.cpp/docs/build.md +37 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +14 -14
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
- package/src/llama.cpp/examples/infill/infill.cpp +40 -86
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/clip.cpp +1 -0
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +37 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
- package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
- package/src/llama.cpp/examples/main/main.cpp +64 -109
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
- package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
- package/src/llama.cpp/examples/server/server.cpp +553 -691
- package/src/llama.cpp/examples/server/utils.hpp +312 -25
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +128 -96
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
- package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +53 -393
- package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
- package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
- package/src/llama.cpp/include/llama.h +67 -33
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-sampling.cpp +745 -105
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +49 -9
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +2636 -2406
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/tests/CMakeLists.txt +1 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
- package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +1 -0
- package/src/llama.cpp/tests/test-sampling.cpp +162 -137
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
- /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
// Unit tests for quantization specific functions - quantize, dequantize and dot product
|
|
2
2
|
|
|
3
3
|
#include "ggml.h"
|
|
4
|
+
#include "ggml-cpu.h"
|
|
4
5
|
|
|
5
6
|
#undef NDEBUG
|
|
6
7
|
#include <assert.h>
|
|
@@ -44,26 +45,27 @@ static float array_rmse(const float * a1, const float * a2, size_t n) {
|
|
|
44
45
|
}
|
|
45
46
|
|
|
46
47
|
// Total quantization error on test data
|
|
47
|
-
static float total_quantization_error(
|
|
48
|
+
static float total_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
|
|
48
49
|
std::vector<uint8_t> tmp_q(2*test_size);
|
|
49
50
|
std::vector<float> tmp_out(test_size);
|
|
50
51
|
|
|
51
|
-
|
|
52
|
-
qfns
|
|
52
|
+
qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
|
|
53
|
+
qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
|
|
53
54
|
return array_rmse(test_data, tmp_out.data(), test_size);
|
|
54
55
|
}
|
|
55
56
|
|
|
56
57
|
// Total quantization error on test data
|
|
57
|
-
static float reference_quantization_error(
|
|
58
|
+
static float reference_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
|
|
58
59
|
std::vector<uint8_t> tmp_q(2*test_size);
|
|
59
60
|
std::vector<float> tmp_out(test_size);
|
|
60
61
|
std::vector<float> tmp_out_ref(test_size);
|
|
61
62
|
|
|
62
|
-
|
|
63
|
-
|
|
63
|
+
// FIXME: why is done twice?
|
|
64
|
+
qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
|
|
65
|
+
qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
|
|
64
66
|
|
|
65
|
-
qfns
|
|
66
|
-
qfns
|
|
67
|
+
qfns->from_float_ref(test_data, tmp_q.data(), test_size);
|
|
68
|
+
qfns->to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
|
|
67
69
|
|
|
68
70
|
return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
|
|
69
71
|
}
|
|
@@ -78,18 +80,18 @@ static float dot_product(const float * a1, const float * a2, size_t test_size) {
|
|
|
78
80
|
|
|
79
81
|
// Total dot product error
|
|
80
82
|
static float dot_product_error(
|
|
81
|
-
|
|
83
|
+
const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float *test_data2
|
|
82
84
|
) {
|
|
83
85
|
std::vector<uint8_t> tmp_q1(2*test_size);
|
|
84
86
|
std::vector<uint8_t> tmp_q2(2*test_size);
|
|
85
87
|
|
|
86
|
-
auto vdot =
|
|
88
|
+
const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
|
|
87
89
|
|
|
88
|
-
|
|
89
|
-
vdot
|
|
90
|
+
qfns_cpu->from_float(test_data1, tmp_q1.data(), test_size);
|
|
91
|
+
vdot->from_float(test_data2, tmp_q2.data(), test_size);
|
|
90
92
|
|
|
91
93
|
float result = INFINITY;
|
|
92
|
-
|
|
94
|
+
qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
|
|
93
95
|
|
|
94
96
|
const float dot_ref = dot_product(test_data1, test_data2, test_size);
|
|
95
97
|
|
|
@@ -131,10 +133,11 @@ int main(int argc, char * argv[]) {
|
|
|
131
133
|
|
|
132
134
|
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
|
|
133
135
|
ggml_type type = (ggml_type) i;
|
|
134
|
-
|
|
136
|
+
const auto * qfns = ggml_get_type_traits(type);
|
|
137
|
+
const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
|
|
135
138
|
|
|
136
139
|
// deprecated - skip
|
|
137
|
-
if (qfns
|
|
140
|
+
if (qfns->blck_size == 0) {
|
|
138
141
|
continue;
|
|
139
142
|
}
|
|
140
143
|
|
|
@@ -143,8 +146,8 @@ int main(int argc, char * argv[]) {
|
|
|
143
146
|
printf("Testing %s\n", ggml_type_name((ggml_type) i));
|
|
144
147
|
ggml_quantize_init(ei);
|
|
145
148
|
|
|
146
|
-
if (
|
|
147
|
-
const float total_error = total_quantization_error(qfns, test_size, test_data.data());
|
|
149
|
+
if (qfns_cpu->from_float && qfns->to_float) {
|
|
150
|
+
const float total_error = total_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
|
|
148
151
|
const float max_quantization_error =
|
|
149
152
|
type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
|
|
150
153
|
type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
|
|
@@ -159,14 +162,14 @@ int main(int argc, char * argv[]) {
|
|
|
159
162
|
printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
|
|
160
163
|
}
|
|
161
164
|
|
|
162
|
-
const float reference_error = reference_quantization_error(qfns, test_size, test_data.data());
|
|
165
|
+
const float reference_error = reference_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
|
|
163
166
|
failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
|
|
164
167
|
num_failed += failed;
|
|
165
168
|
if (failed || verbose) {
|
|
166
169
|
printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
|
|
167
170
|
}
|
|
168
171
|
|
|
169
|
-
const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data());
|
|
172
|
+
const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());
|
|
170
173
|
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
|
|
171
174
|
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
|
|
172
175
|
? MAX_DOT_PRODUCT_ERROR_LOWBIT
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
// Benchmark quantization specific functions on synthetic data
|
|
2
2
|
|
|
3
3
|
#include "ggml.h"
|
|
4
|
+
#include "ggml-cpu.h"
|
|
4
5
|
|
|
5
6
|
#undef NDEBUG
|
|
6
7
|
#include <algorithm>
|
|
7
8
|
#include <assert.h>
|
|
8
9
|
#include <functional>
|
|
9
|
-
#include <inttypes.h>
|
|
10
10
|
#include <math.h>
|
|
11
11
|
#include <memory>
|
|
12
12
|
#include <stdio.h>
|
|
@@ -122,9 +122,10 @@ static void usage(char * argv[]) {
|
|
|
122
122
|
printf(" --type TYPE set test type as");
|
|
123
123
|
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
|
|
124
124
|
ggml_type type = (ggml_type) i;
|
|
125
|
-
|
|
125
|
+
const auto * qfns = ggml_get_type_traits(type);
|
|
126
|
+
const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
|
|
126
127
|
if (ggml_type_name(type) != NULL) {
|
|
127
|
-
if (
|
|
128
|
+
if (qfns_cpu->from_float && qfns->to_float) {
|
|
128
129
|
printf(" %s", ggml_type_name(type));
|
|
129
130
|
}
|
|
130
131
|
}
|
|
@@ -270,12 +271,13 @@ int main(int argc, char * argv[]) {
|
|
|
270
271
|
|
|
271
272
|
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
|
|
272
273
|
ggml_type type = (ggml_type) i;
|
|
273
|
-
|
|
274
|
+
const auto * qfns = ggml_get_type_traits(type);
|
|
275
|
+
const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
|
|
274
276
|
if (!params.include_types.empty() && ggml_type_name(type) && std::find(params.include_types.begin(), params.include_types.end(), ggml_type_name(type)) == params.include_types.end()) {
|
|
275
277
|
continue;
|
|
276
278
|
}
|
|
277
279
|
|
|
278
|
-
if (
|
|
280
|
+
if (qfns_cpu->from_float && qfns->to_float) {
|
|
279
281
|
printf("%s\n", ggml_type_name(type));
|
|
280
282
|
|
|
281
283
|
ggml_quantize_init(type);
|
|
@@ -285,7 +287,7 @@ int main(int argc, char * argv[]) {
|
|
|
285
287
|
for (size_t size : params.test_sizes) {
|
|
286
288
|
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
|
|
287
289
|
auto quantize_fn = [&](void) -> float {
|
|
288
|
-
qfns
|
|
290
|
+
qfns->from_float_ref(test_data1, test_q1, size);
|
|
289
291
|
return test_q1[0];
|
|
290
292
|
};
|
|
291
293
|
size_t quantized_size = ggml_row_size(type, size);
|
|
@@ -299,7 +301,7 @@ int main(int argc, char * argv[]) {
|
|
|
299
301
|
for (size_t size : params.test_sizes) {
|
|
300
302
|
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
|
|
301
303
|
auto quantize_fn = [&](void) -> float {
|
|
302
|
-
|
|
304
|
+
qfns_cpu->from_float(test_data1, test_q1, size);
|
|
303
305
|
return test_q1[0];
|
|
304
306
|
};
|
|
305
307
|
size_t quantized_size = ggml_row_size(type, size);
|
|
@@ -310,11 +312,11 @@ int main(int argc, char * argv[]) {
|
|
|
310
312
|
|
|
311
313
|
if (params.op_dequantize_row_q) {
|
|
312
314
|
printf(" dequantize_row_q\n");
|
|
313
|
-
|
|
315
|
+
qfns_cpu->from_float(test_data1, test_q1, largest);
|
|
314
316
|
for (size_t size : params.test_sizes) {
|
|
315
317
|
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
|
|
316
318
|
auto quantize_fn = [&](void) -> float {
|
|
317
|
-
qfns
|
|
319
|
+
qfns->to_float(test_q1, test_out, size);
|
|
318
320
|
return test_out[0];
|
|
319
321
|
};
|
|
320
322
|
size_t quantized_size = ggml_row_size(type, size);
|
|
@@ -328,8 +330,8 @@ int main(int argc, char * argv[]) {
|
|
|
328
330
|
for (size_t size : params.test_sizes) {
|
|
329
331
|
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
|
|
330
332
|
auto quantize_fn = [&](void) -> float {
|
|
331
|
-
auto vdot =
|
|
332
|
-
vdot
|
|
333
|
+
const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
|
|
334
|
+
vdot->from_float(test_data1, test_q1, size);
|
|
333
335
|
return test_q1[0];
|
|
334
336
|
};
|
|
335
337
|
size_t quantized_size = ggml_row_size(type, size);
|
|
@@ -340,13 +342,13 @@ int main(int argc, char * argv[]) {
|
|
|
340
342
|
|
|
341
343
|
if (params.op_vec_dot_q) {
|
|
342
344
|
printf(" vec_dot_q\n");
|
|
343
|
-
|
|
344
|
-
|
|
345
|
+
qfns_cpu->from_float(test_data1, test_q1, largest);
|
|
346
|
+
qfns_cpu->from_float(test_data2, test_q2, largest);
|
|
345
347
|
for (size_t size : params.test_sizes) {
|
|
346
348
|
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
|
|
347
349
|
auto quantize_fn = [&](void) -> float {
|
|
348
350
|
float result;
|
|
349
|
-
|
|
351
|
+
qfns_cpu->vec_dot(size, &result, 0, test_q1, 0, test_q2, 0, 1);
|
|
350
352
|
return result;
|
|
351
353
|
};
|
|
352
354
|
size_t quantized_size = ggml_row_size(type, size);
|
|
@@ -10,6 +10,8 @@
|
|
|
10
10
|
#include <string>
|
|
11
11
|
#include <vector>
|
|
12
12
|
|
|
13
|
+
extern struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers);
|
|
14
|
+
|
|
13
15
|
static void dump(const llama_token_data_array * cur_p) {
|
|
14
16
|
for (size_t i = 0; i < cur_p->size; i++) {
|
|
15
17
|
printf("%d: %f (%f)\n", cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
|
|
@@ -18,181 +20,188 @@ static void dump(const llama_token_data_array * cur_p) {
|
|
|
18
20
|
|
|
19
21
|
#define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
|
|
20
22
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
23
|
+
struct sampler_tester {
|
|
24
|
+
sampler_tester(size_t n_vocab) {
|
|
25
|
+
cur.reserve(n_vocab);
|
|
26
|
+
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
27
|
+
const float logit = logf(token_id);
|
|
28
|
+
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
29
|
+
}
|
|
26
30
|
|
|
27
|
-
|
|
28
|
-
|
|
31
|
+
cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
|
|
32
|
+
}
|
|
29
33
|
|
|
30
|
-
std::vector<
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
34
|
+
sampler_tester(const std::vector<float> & probs, const std::vector<float> & probs_expected) : probs_expected(probs_expected) {
|
|
35
|
+
cur.reserve(probs.size());
|
|
36
|
+
for (llama_token token_id = 0; token_id < (llama_token)probs.size(); token_id++) {
|
|
37
|
+
const float logit = logf(probs[token_id]);
|
|
38
|
+
cur.emplace_back(llama_token_data{token_id, logit, probs[token_id]});
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
|
|
35
42
|
}
|
|
36
43
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
APPLY(llama_sampler_init_top_k(k), &cur_p);
|
|
41
|
-
DUMP(&cur_p);
|
|
42
|
-
|
|
43
|
-
GGML_ASSERT(cur_p.size == expected_probs.size());
|
|
44
|
-
for (size_t i = 0; i < cur_p.size; i++) {
|
|
45
|
-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
|
|
44
|
+
void apply(llama_sampler * sampler) {
|
|
45
|
+
llama_sampler_apply(sampler, &cur_p);
|
|
46
|
+
llama_sampler_free(sampler);
|
|
46
47
|
}
|
|
47
|
-
}
|
|
48
48
|
|
|
49
|
-
|
|
50
|
-
|
|
49
|
+
void check() {
|
|
50
|
+
GGML_ASSERT(cur_p.size == probs_expected.size());
|
|
51
|
+
for (size_t i = 0; i < cur_p.size; i++) {
|
|
52
|
+
GGML_ASSERT(fabs(cur_p.data[i].p - probs_expected[i]) < 1e-5);
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
llama_token_data_array cur_p;
|
|
57
|
+
|
|
58
|
+
private:
|
|
59
|
+
const std::vector<float> probs_expected;
|
|
51
60
|
|
|
52
61
|
std::vector<llama_token_data> cur;
|
|
53
|
-
|
|
54
|
-
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
55
|
-
const float logit = logf(probs[token_id]);
|
|
56
|
-
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
57
|
-
}
|
|
62
|
+
};
|
|
58
63
|
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
}
|
|
64
|
+
static void test_temp(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp) {
|
|
65
|
+
sampler_tester tester(probs, probs_expected);
|
|
66
|
+
|
|
67
|
+
DUMP(&tester.cur_p);
|
|
68
|
+
tester.apply(llama_sampler_init_temp(temp));
|
|
69
|
+
tester.apply(llama_sampler_init_dist(0));
|
|
70
|
+
DUMP(&tester.cur_p);
|
|
71
|
+
|
|
72
|
+
tester.check();
|
|
69
73
|
}
|
|
70
74
|
|
|
71
|
-
static void
|
|
72
|
-
|
|
75
|
+
static void test_temp_ext(const std::vector<float> & probs, const std::vector<float> & probs_expected, float temp, float delta, float exponent) {
|
|
76
|
+
sampler_tester tester(probs, probs_expected);
|
|
73
77
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
79
|
-
}
|
|
78
|
+
DUMP(&tester.cur_p);
|
|
79
|
+
tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent));
|
|
80
|
+
tester.apply(llama_sampler_init_dist (0));
|
|
81
|
+
DUMP(&tester.cur_p);
|
|
80
82
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
|
|
84
|
-
DUMP(&cur_p);
|
|
83
|
+
tester.check();
|
|
84
|
+
}
|
|
85
85
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
86
|
+
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
|
|
87
|
+
sampler_tester tester(probs, probs_expected);
|
|
88
|
+
|
|
89
|
+
DUMP(&tester.cur_p);
|
|
90
|
+
tester.apply(llama_sampler_init_top_k(k));
|
|
91
|
+
tester.apply(llama_sampler_init_dist (0));
|
|
92
|
+
DUMP(&tester.cur_p);
|
|
93
|
+
|
|
94
|
+
tester.check();
|
|
90
95
|
}
|
|
91
96
|
|
|
92
|
-
static void
|
|
93
|
-
|
|
97
|
+
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
|
|
98
|
+
sampler_tester tester(probs, probs_expected);
|
|
94
99
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
100
|
-
}
|
|
100
|
+
DUMP(&tester.cur_p);
|
|
101
|
+
tester.apply(llama_sampler_init_top_p(p, 1));
|
|
102
|
+
tester.apply(llama_sampler_init_dist (0));
|
|
103
|
+
DUMP(&tester.cur_p);
|
|
101
104
|
|
|
102
|
-
|
|
103
|
-
DUMP(&cur_p);
|
|
104
|
-
APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
|
|
105
|
-
DUMP(&cur_p);
|
|
106
|
-
APPLY(llama_sampler_init_softmax(), &cur_p);
|
|
107
|
-
|
|
108
|
-
GGML_ASSERT(cur_p.size == expected_probs.size());
|
|
109
|
-
for (size_t i = 0; i < cur_p.size; i++) {
|
|
110
|
-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
|
|
111
|
-
}
|
|
105
|
+
tester.check();
|
|
112
106
|
}
|
|
113
107
|
|
|
114
|
-
static void
|
|
115
|
-
|
|
108
|
+
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
|
|
109
|
+
sampler_tester tester(probs, probs_expected);
|
|
116
110
|
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
122
|
-
}
|
|
111
|
+
DUMP(&tester.cur_p);
|
|
112
|
+
tester.apply(llama_sampler_init_min_p(p, 1));
|
|
113
|
+
tester.apply(llama_sampler_init_dist (0));
|
|
114
|
+
DUMP(&tester.cur_p);
|
|
123
115
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
APPLY(llama_sampler_init_typical(p, 1), &cur_p);
|
|
127
|
-
DUMP(&cur_p);
|
|
116
|
+
tester.check();
|
|
117
|
+
}
|
|
128
118
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
119
|
+
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p, float t) {
|
|
120
|
+
sampler_tester tester(probs, probs_expected);
|
|
121
|
+
|
|
122
|
+
DUMP(&tester.cur_p);
|
|
123
|
+
tester.apply(llama_sampler_init_xtc(p, t, 0, 0));
|
|
124
|
+
DUMP(&tester.cur_p);
|
|
125
|
+
|
|
126
|
+
tester.check();
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
static void test_typical(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
|
|
130
|
+
sampler_tester tester(probs, probs_expected);
|
|
131
|
+
|
|
132
|
+
DUMP(&tester.cur_p);
|
|
133
|
+
tester.apply(llama_sampler_init_typical(p, 1));
|
|
134
|
+
DUMP(&tester.cur_p);
|
|
135
|
+
|
|
136
|
+
tester.check();
|
|
133
137
|
}
|
|
134
138
|
|
|
135
139
|
static void test_penalties(
|
|
136
140
|
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
|
137
|
-
const std::vector<float> &
|
|
141
|
+
const std::vector<float> & probs_expected, float repeat_penalty, float alpha_frequency, float alpha_presence
|
|
138
142
|
) {
|
|
139
|
-
GGML_ASSERT(probs.size() ==
|
|
143
|
+
GGML_ASSERT(probs.size() == probs_expected.size());
|
|
144
|
+
|
|
145
|
+
sampler_tester tester(probs, probs_expected);
|
|
140
146
|
|
|
141
147
|
const size_t n_vocab = probs.size();
|
|
148
|
+
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
|
|
142
149
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
146
|
-
const float logit = logf(probs[token_id]);
|
|
147
|
-
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
150
|
+
for (size_t i = 0; i < last_tokens.size(); i++) {
|
|
151
|
+
llama_sampler_accept(sampler, last_tokens[i]);
|
|
148
152
|
}
|
|
149
153
|
|
|
150
|
-
|
|
154
|
+
DUMP(&tester.cur_p);
|
|
155
|
+
tester.apply(sampler);
|
|
156
|
+
tester.apply(llama_sampler_init_dist(0));
|
|
157
|
+
DUMP(&tester.cur_p);
|
|
151
158
|
|
|
152
|
-
|
|
159
|
+
tester.check();
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
static void test_dry(
|
|
163
|
+
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
|
164
|
+
const std::vector<float> & expected_probs, float dry_multiplier, float dry_base,
|
|
165
|
+
int dry_allowed_length, int dry_penalty_last_n,
|
|
166
|
+
const std::vector<std::vector<llama_token>> & seq_breakers
|
|
167
|
+
) {
|
|
168
|
+
GGML_ASSERT(probs.size() == expected_probs.size());
|
|
169
|
+
|
|
170
|
+
sampler_tester tester(probs, expected_probs);
|
|
171
|
+
|
|
172
|
+
auto * sampler = llama_sampler_init_dry_testing(1024, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers);
|
|
153
173
|
|
|
154
174
|
for (size_t i = 0; i < last_tokens.size(); i++) {
|
|
155
175
|
llama_sampler_accept(sampler, last_tokens[i]);
|
|
156
176
|
}
|
|
157
177
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
GGML_ASSERT(cur_p.size == expected_probs.size());
|
|
165
|
-
for (size_t i = 0; i < cur_p.size; i++) {
|
|
166
|
-
GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
|
|
167
|
-
}
|
|
178
|
+
DUMP(&tester.cur_p);
|
|
179
|
+
tester.apply(sampler);
|
|
180
|
+
tester.apply(llama_sampler_init_dist(0));
|
|
181
|
+
DUMP(&tester.cur_p);
|
|
182
|
+
tester.check();
|
|
168
183
|
}
|
|
169
184
|
|
|
170
185
|
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
|
|
171
186
|
) {
|
|
172
|
-
|
|
173
|
-
cur.reserve(n_vocab);
|
|
174
|
-
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
175
|
-
const float logit = logf(token_id);
|
|
176
|
-
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
177
|
-
}
|
|
178
|
-
|
|
179
|
-
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
|
187
|
+
sampler_tester tester(n_vocab);
|
|
180
188
|
|
|
181
189
|
llama_token min_token_id = 0;
|
|
182
190
|
const llama_token max_token_id = n_vocab-1;
|
|
183
191
|
|
|
184
192
|
for (auto s : samplers_sequence) {
|
|
185
193
|
switch (s){
|
|
186
|
-
case 'k':
|
|
187
|
-
case 'f': GGML_ABORT("tail_free test not implemented");
|
|
194
|
+
case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
|
|
188
195
|
case 'y': GGML_ABORT("typical test not implemented");
|
|
189
|
-
case 'p':
|
|
190
|
-
case 'm':
|
|
196
|
+
case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
|
|
197
|
+
case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
|
|
191
198
|
case 't': GGML_ABORT("temperature test not implemented");
|
|
192
199
|
default : GGML_ABORT("Unknown sampler");
|
|
193
200
|
}
|
|
194
201
|
|
|
195
|
-
|
|
202
|
+
tester.apply(llama_sampler_init_dist(0));
|
|
203
|
+
|
|
204
|
+
auto & cur_p = tester.cur_p;
|
|
196
205
|
|
|
197
206
|
const int size = cur_p.size;
|
|
198
207
|
|
|
@@ -263,7 +272,7 @@ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vec
|
|
|
263
272
|
}
|
|
264
273
|
const int64_t t_end = ggml_time_us();
|
|
265
274
|
llama_sampler_free(cnstr);
|
|
266
|
-
printf("%-
|
|
275
|
+
printf("%-43s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
|
|
267
276
|
}
|
|
268
277
|
|
|
269
278
|
#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
|
|
@@ -279,26 +288,31 @@ static void test_perf() {
|
|
|
279
288
|
data.emplace_back(llama_token_data{i, logit, 0.0f});
|
|
280
289
|
}
|
|
281
290
|
|
|
282
|
-
BENCH(llama_sampler_init_top_k
|
|
283
|
-
BENCH(llama_sampler_init_top_p
|
|
284
|
-
BENCH(llama_sampler_init_min_p
|
|
285
|
-
BENCH(
|
|
286
|
-
BENCH(
|
|
287
|
-
BENCH(llama_sampler_init_softmax (), data, 32);
|
|
291
|
+
BENCH(llama_sampler_init_top_k (40), data, 32);
|
|
292
|
+
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
|
|
293
|
+
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
|
|
294
|
+
BENCH(llama_sampler_init_typical(0.5f, 1), data, 32);
|
|
295
|
+
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
|
|
288
296
|
}
|
|
289
297
|
|
|
290
298
|
int main(void) {
|
|
291
299
|
ggml_time_init();
|
|
292
300
|
|
|
293
|
-
|
|
294
|
-
|
|
301
|
+
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
|
|
302
|
+
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
|
|
303
|
+
|
|
304
|
+
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
|
|
305
|
+
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
|
|
306
|
+
|
|
307
|
+
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
|
|
308
|
+
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
|
|
295
309
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
|
296
310
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
|
|
297
311
|
|
|
298
|
-
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {
|
|
299
|
-
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.
|
|
300
|
-
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.
|
|
301
|
-
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
|
|
312
|
+
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
|
|
313
|
+
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
|
|
314
|
+
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
|
|
315
|
+
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
|
|
302
316
|
|
|
303
317
|
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
|
|
304
318
|
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
|
|
@@ -309,9 +323,13 @@ int main(void) {
|
|
|
309
323
|
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
|
|
310
324
|
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
|
|
311
325
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
326
|
+
printf("XTC should:\n");
|
|
327
|
+
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.1f}, 0.99f, 0.09f);
|
|
328
|
+
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.2f, 0.1f}, 0.99f, 0.19f);
|
|
329
|
+
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.3f, 0.2f, 0.1f}, 0.99f, 0.29f);
|
|
330
|
+
|
|
331
|
+
printf("XTC should not:\n");
|
|
332
|
+
test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f);
|
|
315
333
|
|
|
316
334
|
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
|
317
335
|
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
|
@@ -324,6 +342,13 @@ int main(void) {
|
|
|
324
342
|
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
|
|
325
343
|
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
|
|
326
344
|
|
|
345
|
+
|
|
346
|
+
test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1}, {0.25f, 0.25f, 0.25f, 0.25f}, 1.0f, 1.1f, 2, 4, {});
|
|
347
|
+
test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.296923f, 0.109232f}, 1.0f, 1.1f, 2, 5, {});
|
|
348
|
+
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 2, 6, {{3}});
|
|
349
|
+
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
|
|
350
|
+
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
|
|
351
|
+
|
|
327
352
|
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
|
328
353
|
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
|
329
354
|
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
|