@fugood/llama.node 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -8
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +4 -2
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +10 -10
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +14 -17
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +5 -4
- package/src/llama.cpp/.github/workflows/build.yml +137 -29
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +46 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +26 -11
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +10 -10
- package/src/llama.cpp/common/arg.cpp +2041 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +523 -1861
- package/src/llama.cpp/common/common.h +234 -106
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +39 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +356 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/docs/build.md +72 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +49 -65
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
- package/src/llama.cpp/examples/infill/infill.cpp +131 -192
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +686 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
- package/src/llama.cpp/examples/llava/llava.cpp +146 -26
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
- package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
- package/src/llama.cpp/examples/main/main.cpp +216 -313
- package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
- package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
- package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
- package/src/llama.cpp/examples/server/server.cpp +1347 -1531
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +396 -107
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +132 -106
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
- package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +272 -505
- package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
- package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
- package/src/llama.cpp/include/llama.h +296 -285
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
- package/src/llama.cpp/src/llama-sampling.h +39 -47
- package/src/llama.cpp/src/llama-vocab.cpp +390 -127
- package/src/llama.cpp/src/llama-vocab.h +60 -20
- package/src/llama.cpp/src/llama.cpp +6215 -3263
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +4 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
- package/src/llama.cpp/tests/test-barrier.cpp +94 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +2 -1
- package/src/llama.cpp/tests/test-sampling.cpp +226 -142
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/common/train.cpp +0 -1513
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/examples/perplexity/perplexity.cpp

@@ -1,18 +1,21 @@
+#include "arg.h"
 #include "common.h"
+#include "log.h"
 #include "llama.h"

+#include <algorithm>
+#include <array>
+#include <atomic>
 #include <cmath>
 #include <cstdio>
 #include <cstring>
 #include <ctime>
+#include <fstream>
+#include <mutex>
+#include <random>
 #include <sstream>
 #include <thread>
-#include <mutex>
-#include <atomic>
 #include <vector>
-#include <array>
-#include <fstream>
-#include <sstream>

 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -31,55 +34,6 @@ struct results_log_softmax {
     float prob;
 };

-static void write_logfile(
-    const llama_context * ctx, const gpt_params & params, const llama_model * model,
-    const struct results_perplexity & results
-) {
-    if (params.logdir.empty()) {
-        return;
-    }
-
-    if (params.hellaswag) {
-        fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__);
-        return;
-    }
-
-    const std::string timestamp = string_get_sortable_timestamp();
-
-    const bool success = fs_create_directory_with_parents(params.logdir);
-    if (!success) {
-        fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
-            __func__, params.logdir.c_str());
-        return;
-    }
-
-    const std::string logfile_path = params.logdir + timestamp + ".yml";
-    FILE * logfile = fopen(logfile_path.c_str(), "w");
-
-    if (logfile == NULL) {
-        fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
-        return;
-    }
-
-    fprintf(logfile, "binary: main\n");
-    char model_desc[128];
-    llama_model_desc(model, model_desc, sizeof(model_desc));
-    yaml_dump_non_result_info(logfile, params, ctx, timestamp, results.tokens, model_desc);
-
-    fprintf(logfile, "\n");
-    fprintf(logfile, "######################\n");
-    fprintf(logfile, "# Perplexity Results #\n");
-    fprintf(logfile, "######################\n");
-    fprintf(logfile, "\n");
-
-    yaml_dump_vector_float(logfile, "logits", results.logits);
-    fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
-    yaml_dump_vector_float(logfile, "probs", results.probs);
-
-    llama_dump_timing_info_yaml(logfile, ctx);
-    fclose(logfile);
-}
-
 static std::vector<float> softmax(const std::vector<float>& logits) {
     std::vector<float> probs(logits.size());
     float max_logit = logits[0];
@@ -166,7 +120,7 @@ static void process_logits(
                 break;
             }
             lock.unlock();
-            const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const results_log_softmax results = log_softmax(n_vocab, logits + size_t(i)*n_vocab, tokens[i+1]);
             const double v = -results.log_softmax;
             local_nll += v;
             local_nll2 += v*v;
@@ -200,7 +154,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
                 break;
             }
             lock.unlock();
-            const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+            const double v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
             local_nll += v;
             local_nll2 += v*v;
         }
@@ -278,7 +232,9 @@ static std::pair<double, float> log_softmax(int n_vocab, const float * logits, c
     kld.sum_kld += sum;
     kld.sum_kld2 += sum*sum;
     ++kld.count;
-    if (imax == imax_base)
+    if (imax == imax_base) {
+        ++kld.n_same_top;
+    }

     const float p_base = expf(-nll_base);
     const float p = expf(-nll);
@@ -320,7 +276,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
                 break;
             }
             lock.unlock();
-            std::pair<double, float> v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+            std::pair<double, float> v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
             kld_values[i] = (float)v.first;
             p_diff_values[i] = v.second;
         }
@@ -334,25 +290,25 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
     }
 }

-static results_perplexity perplexity_v2(llama_context * ctx, const
+static results_perplexity perplexity_v2(llama_context * ctx, const common_params & params) {
     // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval

-    const bool add_bos =
-    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx))
+    const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
+    GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));

-
+    LOG_INF("%s: tokenizing the input ..\n", __func__);

-    std::vector<llama_token> tokens =
+    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);

     const int n_ctx = llama_n_ctx(ctx);

     if (int(tokens.size()) < 2*n_ctx) {
-
+        LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
             n_ctx);
-
+        LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
         return {std::move(tokens), 0., {}, {}};
     }

@@ -363,16 +319,16 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     prob_history.resize(tokens.size());

     if (params.ppl_stride <= 0) {
-
+        LOG_ERR("%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
         return {tokens, -1, logit_history, prob_history};
     }

     const int calc_chunk = n_ctx;

-
+    LOG_INF("%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);

     if (int(tokens.size()) <= calc_chunk) {
-
+        LOG_ERR("%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
             tokens.size(), n_ctx, params.ppl_stride);
         return {tokens, -1, logit_history, prob_history};
     }
@@ -380,20 +336,21 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;

     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     int count = 0;
     double nll = 0.0;

-
+    LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * params.ppl_stride;
         const int end = start + calc_chunk;

         const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
-        //
+        //LOG_DBG("%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);

         std::vector<float> logits;

@@ -402,14 +359,21 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         // clear the KV cache
         llama_kv_cache_clear(ctx);

+        llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size = std::min(end - batch_start, n_batch);

-
-
-
-
+            common_batch_clear(batch);
+            for (int i = 0; i < batch_size; i++) {
+                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+            }
+
+            //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
+            if (llama_decode(ctx, batch)) {
+                //LOG_ERR("%s : failed to eval\n", __func__);
+                llama_batch_free(batch);
                 return {tokens, -1, logit_history, prob_history};
             }

@@ -421,34 +385,35 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }

-            const auto batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            const auto * batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);

             if (j == 0) {
                 tokens[batch_start] = token_org;
             }
         }

+        llama_batch_free(batch);
+
         const auto t_end = std::chrono::high_resolution_clock::now();

         if (i == 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
-
+            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
             int total_seconds = (int)(t_total * n_chunk);
             if (total_seconds >= 60*60) {
-
+                LOG("%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
-
+            LOG("%.2f minutes\n", total_seconds / 60.0);
         }

-        //
+        //LOG_DBG("%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
         for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
-
             // Calculate probability of next token, given the previous ones.
             const std::vector<float> tok_logits(
-                logits.begin() + (j + 0) * n_vocab,
-                logits.begin() + (j + 1) * n_vocab);
+                logits.begin() + size_t(j + 0) * n_vocab,
+                logits.begin() + size_t(j + 1) * n_vocab);

             const float prob = softmax(tok_logits)[tokens[start + j + 1]];
             logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
@@ -459,18 +424,17 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         }
         // perplexity is e^(average negative log-likelihood)
         if (params.ppl_output_type == 0) {
-
+            LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
         } else {
-
+            LOG("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
         }
-        fflush(stdout);
     }
-
+    LOG("\n");

     return {tokens, std::exp(nll / count), logit_history, prob_history};
 }

-static results_perplexity perplexity(llama_context * ctx, const
+static results_perplexity perplexity(llama_context * ctx, const common_params & params, const int32_t n_ctx) {
     if (params.ppl_stride > 0) {
         return perplexity_v2(ctx, params);
     }
@@ -480,33 +444,33 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval

-    const bool add_bos =
-    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx))
+    const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
+    GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));

     std::ofstream logits_stream;
     if (!params.logits_file.empty()) {
         logits_stream.open(params.logits_file.c_str(), std::ios::binary);
         if (!logits_stream.is_open()) {
-
+            LOG_ERR("%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
             return {};
         }
-
+        LOG_INF("%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
         logits_stream.write("_logits_", 8);
         logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
     }

     auto tim1 = std::chrono::high_resolution_clock::now();
-
+    LOG_INF("%s: tokenizing the input ..\n", __func__);

-    std::vector<llama_token> tokens =
+    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);

     auto tim2 = std::chrono::high_resolution_clock::now();
-
+    LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());

     if (int(tokens.size()) < 2*n_ctx) {
-
+        LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
             n_ctx);
-
+        LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
         return {std::move(tokens), 0., {}, {}};
     }

@@ -519,9 +483,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     const int n_chunk_max = tokens.size() / n_ctx;

     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     int count = 0;
     double nll = 0.0;
     double nll2 = 0.0;
@@ -536,10 +501,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

     std::vector<float> logits;
     if (num_batches > 1) {
-        logits.reserve((
+        logits.reserve(size_t(n_ctx) * n_vocab);
     }

-
+    LOG_INF("%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);

     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

@@ -612,13 +577,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             }

             if (llama_decode(ctx, batch)) {
-
+                LOG_INF("%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }

             if (num_batches > 1 && n_outputs > 0) {
                 const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
+                logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
             }
         }

@@ -627,13 +592,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             llama_synchronize(ctx);
             const auto t_end = std::chrono::high_resolution_clock::now();
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
-
+            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
             int total_seconds = (int)(t_total*n_chunk/n_seq);
             if (total_seconds >= 60*60) {
-
+                LOG("%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
-
+            LOG("%.2f minutes\n", total_seconds / 60.0);
         }

         for (int seq = 0; seq < n_seq_batch; seq++) {
@@ -655,19 +620,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

             // perplexity is e^(average negative log-likelihood)
             if (params.ppl_output_type == 0) {
-
+                LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
             } else {
                 double av = nll/count;
                 double av2 = nll2/count - av*av;
-                if (av2 > 0)
-
+                if (av2 > 0) {
+                    av2 = sqrt(av2/(count-1));
+                }
+                LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
             }
         }
-        fflush(stdout);

         logits.clear();
     }
-
+    LOG("\n");

     nll2 /= count;
     nll /= count;
@@ -675,9 +641,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     nll2 -= nll * nll;
     if (nll2 > 0) {
         nll2 = sqrt(nll2/(count-1));
-
+        LOG_INF("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
     } else {
-
+        LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
     }

     llama_batch_free(batch);
@@ -685,10 +651,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     return {tokens, ppl, logit_history, prob_history};
 }

-static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits,
+static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
     int prev_outputs = 0;
-    for (
-        const
+    for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
+        const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);

         llama_batch batch_view = {
             n_tokens,
@@ -698,12 +664,11 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
             batch.n_seq_id + i,
             batch.seq_id + i,
             batch.logits + i,
-            0, 0, 0, // unused
         };

         const int ret = llama_decode(ctx, batch_view);
         if (ret != 0) {
-
+            LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
             return false;
         }

@@ -712,7 +677,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
            n_outputs += batch_view.logits[i] != 0;
         }

-        memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));

         prev_outputs += n_outputs;
     }
@@ -727,7 +692,9 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
     if (eval_results.size() != eval_pairs.size()) {
         eval_results.resize(eval_pairs.size());
     }
-    if (eval_pairs.empty())
+    if (eval_pairs.empty()) {
+        return;
+    }

     size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());

@@ -735,11 +702,13 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
     auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
         float local_logprobs[K_TOKEN_CHUNK];
         while (true) {
-            size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
-            if (first >= eval_results.size())
-
+            const size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
+            if (first >= eval_results.size()) {
+                break;
+            }
+            const size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
             for (size_t i = first; i < last; ++i) {
-                auto logits = batch_logits + eval_pairs[i].first * n_vocab;
+                const auto * logits = batch_logits + eval_pairs[i].first * n_vocab;
                 float max_logit = logits[0];
                 for (int j = 1; j < n_vocab; ++j) {
                     max_logit = std::max(max_logit, logits[j]);
@@ -762,7 +731,7 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
     }
 }

-static void hellaswag_score(llama_context * ctx, const
+static void hellaswag_score(llama_context * ctx, const common_params & params) {
     // Calculates hellaswag score (acc_norm) from prompt
     //
     // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
@@ -789,15 +758,15 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     }

     if (prompt_lines.size() % 6 != 0) {
-
+        LOG_ERR("%s : number of lines in prompt not a multiple of 6.\n", __func__);
         return;
     }

     size_t hs_task_count = prompt_lines.size()/6;
-
+    LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);

     const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
-
+    LOG_INF("================================= is_spm = %d\n", is_spm);

     // The tasks should be randomized so the score stabilizes quickly.
     bool randomize_tasks = true;
@@ -824,7 +793,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         std::vector<llama_token> seq_tokens[4];
     };

-
+    LOG_INF("%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );

     // Select and read data from prompt lines
     std::vector<hs_data_t> hs_data(hs_task_count);
@@ -843,7 +812,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
         for (size_t j = 0; j < 4; j++) {
             hs_cur.ending[j] = prompt_lines[idx*6+2+j];
-            hs_cur.seq_tokens[j] =
+            hs_cur.seq_tokens[j] = common_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true);
         }

         // determine the common prefix of the endings
@@ -870,16 +839,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }
     }

-
+    LOG_INF("%s : calculating hellaswag score over selected tasks.\n", __func__);

-
+    LOG("\ntask\tacc_norm\n");

     double acc = 0.0f;

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

@@ -887,7 +857,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
-    std::vector<float> batch_logits(
+    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
@@ -899,7 +869,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         size_t i1 = i0;
         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

-
+        common_batch_clear(batch);

         // batch as much tasks as possible into the available context
         // each task has 4 unique sequence ids - one for each ending
@@ -915,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             }

             for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
-
+                common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
             }
             batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
             n_logits += 1;
@@ -925,7 +895,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-
+                    common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -940,7 +910,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }

         if (i0 == i1) {
-
+            LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0);
             return;
         }

@@ -948,7 +918,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

         // decode all tasks [i0, i1)
         if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
-
+            LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }

@@ -974,7 +944,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             auto & hs_cur = hs_data[i];

             // get the logits of the last token of the common prefix
-            std::memcpy(tok_logits.data(), batch_logits.data() +
+            std::memcpy(tok_logits.data(), batch_logits.data() + hs_cur.i_logits*n_vocab, n_vocab*sizeof(float));

             const auto first_probs = softmax(tok_logits);

@@ -998,7 +968,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
                 }
             }

-            //
+            //LOG("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);

             // If the gold ending got the maximum logprobe add one accuracy point
             if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
@@ -1006,8 +976,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             }

             // Print the accumulated accuracy mean x 100
-
-            fflush(stdout);
+            LOG("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
         }

         i0 = i1 - 1;
@@ -1015,7 +984,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {

     llama_batch_free(batch);

-
+    LOG("\n");
 }

 struct winogrande_entry {
@@ -1059,7 +1028,7 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
             }
         }
         if (ipos != 4) {
-
+            LOG_ERR("%s: failed to find comma separators in <%s>\n", __func__, line.c_str());
             continue;
         }
         auto sentence = line[comma_pos[0]+1] == '"' ? line.substr(comma_pos[0]+2, comma_pos[1] - comma_pos[0] - 3)
@@ -1073,13 +1042,13 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
             if (sentence[where] == '_') break;
         }
         if (where == int(sentence.size())) {
-
+            LOG_ERR("%s: no _ in <%s>\n", __func__, sentence.c_str());
             continue;
         }
         std::istringstream stream(answer.c_str());
         int i_answer; stream >> i_answer;
         if (stream.fail() || i_answer < 1 || i_answer > 2) {
-
+            LOG_ERR("%s: failed to parse answer <%s>\n", __func__, answer.c_str());
             continue;
         }
         result.emplace_back();
@@ -1102,20 +1071,20 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
  * 0,Sarah was a much better surgeon than Maria so _ always got the easier cases.,Sarah,Maria,2
  *
  */
-static void winogrande_score(llama_context * ctx, const
+static void winogrande_score(llama_context * ctx, const common_params & params) {

     constexpr int k_min_trailing_ctx = 3;

     auto data = load_winogrande_from_csv(params.prompt);
     if (data.empty()) {
-
+        LOG_ERR("%s: no tasks\n", __func__);
         return;
     }

-
+    LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, data.size());

     if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) {
-
+        LOG_INF("%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks);
         std::mt19937 rng(1);
         std::vector<int> aux(data.size());
         for (int i = 0; i < int(data.size()); ++i) {
@@ -1133,11 +1102,11 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         data = std::move(selected);
     }

-
+    LOG_INF("%s : tokenizing selected tasks\n", __func__);

     for (auto & task : data) {
-        task.seq_tokens[0] =
-        task.seq_tokens[1] =
+        task.seq_tokens[0] = common_tokenize(ctx, task.first + task.choices[0] + task.second, true);
+        task.seq_tokens[1] = common_tokenize(ctx, task.first + task.choices[1] + task.second, true);

         task.common_prefix = 0;
         for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
@@ -1152,16 +1121,17 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             task.seq_tokens[0].size() - task.common_prefix +
             task.seq_tokens[1].size() - task.common_prefix;

-        task.n_base1 =
-        task.n_base2 =
+        task.n_base1 = common_tokenize(ctx, task.first + task.choices[0], true).size();
+        task.n_base2 = common_tokenize(ctx, task.first + task.choices[1], true).size();
     }

-
+    LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

@@ -1169,7 +1139,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {

     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
-    std::vector<float> batch_logits(
+    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
@@ -1184,7 +1154,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         size_t i1 = i0;
         size_t i_logits = 0;

-
+        common_batch_clear(batch);

         while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
             int n_logits = 0;
@@ -1194,7 +1164,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             }

             for (size_t i = 0; i < data[i1].common_prefix; ++i) {
-
+                common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
             }
             batch.logits[batch.n_tokens - 1] = true;
             n_logits += 1;
@@ -1202,7 +1172,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             for (int s = 0; s < 2; ++s) {
                 // TODO: end before the last token, no need to predict past the end of the sequences
                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
-
+                    common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
                     n_logits += 1;
                 }
             }
@@ -1217,7 +1187,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         }

         if (i0 == i1) {
-
+            LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0);
             return;
         }

@@ -1225,7 +1195,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {

         // decode all tasks [i0, i1)
         if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
-
+            LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }

@@ -1285,20 +1255,20 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             ++n_done;

             // print the accumulated accuracy mean x 100
-
-            fflush(stdout);
+            LOG("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
         }

         i0 = i1 - 1;
     }

-
+    LOG("\n");

     if (n_done < 100) return;

     const float p = 1.f*n_correct/n_done;
     const float sigma = 100.f*sqrt(p*(1-p)/(n_done-1));
-
+
+    LOG_INF("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma);
 }

 static bool deserialize_string(std::istream & in, std::string & str) {
@@ -1347,7 +1317,7 @@ struct multiple_choice_task {
 static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) {
     if (task.question.empty() || task.mc1.answers.empty()) {
         if (log_error) {
-
+            LOG_ERR("%s: found bad task with empty question and/or answers\n", __func__);
         }
         return false;
     }
@@ -1355,11 +1325,11 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic
     for (auto& answer : task.mc1.answers) {
         if (answer.empty()) {
             if (log_error) {
-
+                LOG_ERR("%s: found empty answer\n", __func__);
             }
             return false;
         }
-        task.seq_tokens.emplace_back(::
+        task.seq_tokens.emplace_back(::common_tokenize(ctx, task.question + " " + answer, true));
     }
     auto min_len = task.seq_tokens.front().size();
     for (auto& seq : task.seq_tokens) {
@@ -1403,20 +1373,20 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic
 // git@hf.co:datasets/Stevross/mmlu
 // https://huggingface.co/datasets/truthful_qa
 //
-static void multiple_choice_score(llama_context * ctx, const
+static void multiple_choice_score(llama_context * ctx, const common_params & params) {

     std::istringstream strstream(params.prompt);
     uint32_t n_task;
     strstream.read((char *)&n_task, sizeof(n_task));
     if (strstream.fail() || n_task == 0) {
-
+        LOG_ERR("%s: no tasks\n", __func__);
         return;
     }
-
+    LOG_INF("%s: there are %u tasks in prompt\n", __func__, n_task);
     std::vector<uint32_t> task_pos(n_task);
     strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
     if (strstream.fail()) {
-
+        LOG_ERR("%s: failed to read task positions from prompt\n", __func__);
         return;
     }

@@ -1424,21 +1394,21 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
     if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) {
         // Use all tasks
         tasks.resize(n_task);
-
+        LOG_INF("%s: reading tasks", __func__);
         int n_dot = std::max((int) n_task/100, 1);
         int i = 0;
         for (auto& task : tasks) {
             ++i;
             if (!task.deserialize(strstream)) {
-
+                LOG_ERR("%s: failed to read task %d of %u\n", __func__, i, n_task);
                 return;
             }
-            if (i%n_dot == 0)
+            if (i%n_dot == 0) LOG(".");
         }
-
+        LOG("done\n");
     }
     else {
-
+        LOG_INF("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task);
         std::mt19937 rng(1);
         std::vector<int> aux(n_task);
         for (uint32_t i = 0; i < n_task; ++i) aux[i] = i;
@@ -1451,18 +1421,16 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
|
|
|
1451
1421
|
aux.pop_back();
|
|
1452
1422
|
strstream.seekg(task_pos[idx], std::ios::beg);
|
|
1453
1423
|
if (!task.deserialize(strstream)) {
|
|
1454
|
-
|
|
1424
|
+
LOG_ERR("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]);
|
|
1455
1425
|
return;
|
|
1456
1426
|
}
|
|
1457
1427
|
}
|
|
1458
1428
|
n_task = params.multiple_choice_tasks;
|
|
1459
1429
|
}
|
|
1460
1430
|
|
|
1461
|
-
|
|
1462
|
-
fflush(stdout);
|
|
1431
|
+
LOG_INF("%s: preparing task data", __func__);
|
|
1463
1432
|
if (n_task > 500) {
|
|
1464
|
-
|
|
1465
|
-
fflush(stdout);
|
|
1433
|
+
LOG("...");
|
|
1466
1434
|
std::atomic<int> counter(0);
|
|
1467
1435
|
std::atomic<int> n_bad(0);
|
|
1468
1436
|
auto prepare = [&counter, &n_bad, &tasks, ctx] () {
|
|
@@ -1486,11 +1454,10 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
|
|
|
1486
1454
|
for (auto& w : workers) w = std::thread(prepare);
|
|
1487
1455
|
prepare();
|
|
1488
1456
|
for (auto& w : workers) w.join();
|
|
1489
|
-
|
|
1490
|
-
fflush(stdout);
|
|
1457
|
+
LOG("done\n");
|
|
1491
1458
|
int nbad = n_bad;
|
|
1492
1459
|
if (nbad > 0) {
|
|
1493
|
-
|
|
1460
|
+
LOG_ERR("%s: found %d malformed tasks\n", __func__, nbad);
|
|
1494
1461
|
return;
|
|
1495
1462
|
}
|
|
1496
1463
|
} else {
|
|
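The hunks above swap raw `printf`/`fprintf(stderr, ...)`/`fflush(stdout)` output for the logging macros from the reworked `common/log.h` shipped in this package. A minimal sketch of that usage pattern; the helper function below is hypothetical and not part of perplexity.cpp:

```cpp
// Sketch only: illustrates the LOG / LOG_INF / LOG_ERR pattern adopted in this diff.
// `load_tasks` is a hypothetical example function, not code from the package.
#include "log.h"   // common/log.h from this package

static bool load_tasks(int n_task) {
    if (n_task <= 0) {
        LOG_ERR("%s: no tasks\n", __func__);            // error-level message
        return false;
    }
    LOG_INF("%s: reading %d tasks", __func__, n_task);  // info-level message
    for (int i = 0; i < n_task; ++i) {
        if (i % 10 == 0) LOG(".");                      // raw progress output, no manual fflush needed
    }
    LOG("done\n");
    return true;
}
```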
@@ -1502,28 +1469,28 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
                 return;
             }
             if (i_task%n_dot == 0) {
-
-                fflush(stdout);
+                LOG(".");
             }
         }
-
+        LOG("done\n");
     }

-
+    LOG_INF("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size());

-
+    LOG("\ntask\tacc_norm\n");

-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;

+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

     std::vector<float> tok_logits(n_vocab);
-    std::vector<float> batch_logits(
+    std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);

     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
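The new `size_t(n_ctx)*n_vocab` allocation above widens one operand before multiplying, so the element count is not computed in 32-bit `int`. A small self-contained illustration, using hypothetical sizes chosen only to make the overflow point:

```cpp
// Sketch: why the size_t(...) cast matters. The values below are hypothetical,
// picked only so that int * int would overflow while size_t * int does not.
#include <cstddef>
#include <cstdio>

int main() {
    const int n_ctx   = 131072;
    const int n_vocab = 128256;
    // Widening the first operand keeps the whole multiplication in size_t.
    const std::size_t n_elems = std::size_t(n_ctx) * n_vocab;
    std::printf("%zu floats = %.1f GiB\n", n_elems, n_elems * sizeof(float) / double(1u << 30));
    return 0;
}
```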
@@ -1540,7 +1507,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         size_t i1 = i0;
         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

-
+        common_batch_clear(batch);

         // batch as much tasks as possible into the available context
         // each task has 4 unique sequence ids - one for each ending
@@ -1563,7 +1530,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params

             for (size_t i = 0; i < cur_task.common_prefix; ++i) {
                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
-
+                common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
             }
             batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
             n_logits += 1;
@@ -1573,7 +1540,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-
+                    common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -1590,7 +1557,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         }

         if (i0 == i1) {
-
+            LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0);
             return;
         }

@@ -1598,7 +1565,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params

         // decode all tasks [i0, i1)
         if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
-
+            LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }

@@ -1622,16 +1589,16 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         // compute the logprobs for each ending of the decoded tasks
         for (size_t i = i0; i < i1; ++i) {
             auto & cur_task = tasks[i];
-            //
+            //LOG("==== Evaluating <%s> with correct answer ", cur_task.question.c_str());
             //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) {
             //    if (cur_task.mc1.labels[j] == 1) {
-            //
+            //        LOG("%d", j+1);
             //    }
             //}
-            //
+            //LOG("\n common_prefix: %zu\n", cur_task.common_prefix);

             // get the logits of the last token of the common prefix
-            std::memcpy(tok_logits.data(), batch_logits.data() +
+            std::memcpy(tok_logits.data(), batch_logits.data() + cur_task.i_logits*n_vocab, n_vocab*sizeof(float));

             const auto first_probs = softmax(tok_logits);

@@ -1640,13 +1607,13 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
                 size_t count = 1;
                 float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]);
                 for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
-                    //
+                    //LOG(" %zu %g\n", ir, eval_results[ir]);
                     ++count;
                     log_prob += eval_results[ir++];
                 }
                 cur_task.log_probs[s] = log_prob / count;
-                //
-                //
+                //LOG(" Final: %g\n", log_prob / count);
+                //LOG(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count);
             }

             // Find the ending with maximum logprob
@@ -1666,8 +1633,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
             ++n_done;

             // Print the accumulated accuracy mean x 100
-
-            fflush(stdout);
+            LOG("%d\t%.8lf\n", n_done, 100.*n_correct/n_done);
         }

         i0 = i1 - 1;
@@ -1679,29 +1645,30 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params

     float p = 1.f*n_correct/n_done;
     float sigma = sqrt(p*(1-p)/(n_done-1));
-
+    LOG("\n");
+    LOG_INF("Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);
     p = 1.f*n_done/n_tot_answers;
     sigma = sqrt(p*(1-p)/(n_done-1));
-
+    LOG_INF("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma);

-
+    LOG_INF("\n");
 }

-static void kl_divergence(llama_context * ctx, const
+static void kl_divergence(llama_context * ctx, const common_params & params) {
     if (params.logits_file.empty()) {
-
+        LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
         return;
     }
     std::ifstream in(params.logits_file.c_str(), std::ios::binary);
     if (!in) {
-
+        LOG_ERR("%s: failed to open %s\n", __func__, params.logits_file.c_str());
         return;
     }
     {
         char check[9]; check[8] = 0;
         in.read(check, 8);
         if (in.fail() || strncmp("_logits_", check, 8) != 0) {
-
+            LOG_ERR("%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str());
             return;
         }
     }
@@ -1709,39 +1676,40 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     uint32_t n_ctx;
     in.read((char *)&n_ctx, sizeof(n_ctx));
     if (n_ctx > llama_n_ctx(ctx)) {
-
+        LOG_ERR("%s: %s has been computed with %u, while the current context is %d. Increase it with -c and retry\n",
                 __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
     }

-    int n_vocab
+    int n_vocab;
+    int n_chunk;
     in.read((char *)&n_vocab, sizeof(n_vocab));
     in.read((char *)&n_chunk, sizeof(n_chunk));
     if (in.fail()) {
-
+        LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
         return;
     }
     if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
-
+        LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
     }

-    std::vector<llama_token> tokens(n_ctx * n_chunk);
+    std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
     if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
-
+        LOG_ERR("%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
         return;
     }

     const int n_batch = params.n_batch;
     const int num_batches = (n_ctx + n_batch - 1)/n_batch;
     const int nv = 2*((n_vocab + 1)/2) + 4;
-    const bool add_bos =
-    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx))
+    const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
+    GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));

     std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
     std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
     std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
     std::vector<float> logits;
     if (num_batches > 1) {
-        logits.reserve(n_ctx * n_vocab);
+        logits.reserve(size_t(n_ctx) * n_vocab);
     }

     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@@ -1775,13 +1743,15 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
         const auto t_start = std::chrono::high_resolution_clock::now();

         if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
-
+            LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
             return;
         }

         // clear the KV cache
         llama_kv_cache_clear(ctx);

+        llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size = std::min(end - batch_start, n_batch);
@@ -1794,9 +1764,14 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }

-
-
-
+            common_batch_clear(batch);
+            for (int i = 0; i < batch_size; i++) {
+                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+            }
+
+            if (llama_decode(ctx, batch)) {
+                LOG_ERR("%s : failed to eval\n", __func__);
+                llama_batch_free(batch);
                 return;
             }

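The hunk above rebuilds the evaluation batch with the `common_batch_clear`/`common_batch_add` helpers and frees it on the error path. A condensed sketch of the same pattern for a single sequence, assuming a valid `llama_context` and a token vector; the wrapper function itself is hypothetical:

```cpp
// Sketch of the batch-building pattern used above; not code from the package.
#include <algorithm>
#include <vector>
#include "common.h"
#include "log.h"
#include "llama.h"

static bool decode_tokens(llama_context * ctx, const std::vector<llama_token> & tokens, int n_batch) {
    const int num_batches = ((int) tokens.size() + n_batch - 1) / n_batch;
    llama_batch batch = llama_batch_init(n_batch, 0, 1);     // room for n_batch tokens, 1 sequence

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = j * n_batch;
        const int batch_size  = std::min((int) tokens.size() - batch_start, n_batch);

        common_batch_clear(batch);                            // reuse the same batch every pass
        for (int i = 0; i < batch_size; i++) {
            // token, position, sequence ids, request logits for this token
            common_batch_add(batch, tokens[batch_start + i], batch_start + i, {0}, true);
        }

        if (llama_decode(ctx, batch)) {
            LOG_ERR("%s : failed to eval\n", __func__);
            llama_batch_free(batch);
            return false;
        }
    }

    llama_batch_free(batch);
    return true;
}
```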
@@ -1805,105 +1780,105 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params

             if (num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+                logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
             }
         }

+        llama_batch_free(batch);
+
         const auto t_end = std::chrono::high_resolution_clock::now();

         if (i == 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
-
+            LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
             int total_seconds = (int)(t_total * n_chunk);
             if (total_seconds >= 60*60) {
-
+                LOG("%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
-
-
-            printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
+            LOG("%.2f minutes\n", total_seconds / 60.0);
         }
+        LOG("\n");
+        LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");

         const int first = n_ctx/2;
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+        process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                 workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
         p_diff_ptr += n_ctx - 1 - first;
         kld_ptr += n_ctx - 1 - first;

-
+        LOG("%4d", i+1);

         auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
         const double ppl_val = exp(log_ppl.first);
         const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
-
+        LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);

         auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
         const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
         const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
         const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
-
+        LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);

         auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
-
+        LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);

         auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
         const double p_diff_rms_val = sqrt(p_diff_mse.first);
         const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
-
+        LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);

         double p_top_val = 1.*kld.n_same_top/kld.count;
         double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
-
-
-        printf("\n");
+        LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);

-
+        LOG("\n");

         logits.clear();
     }
-
+    LOG("\n");

     if (kld.count < 100) return; // we do not wish to do statistics on so few values

     std::sort(kld_values.begin(), kld_values.end());
     std::sort(p_diff_values.begin(), p_diff_values.end());

-
+    LOG("====== Perplexity statistics ======\n");

     auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
     const double ppl_val = exp(log_ppl.first);
     const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
-
+    LOG("Mean PPL(Q) : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc);

     auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
     const double ppl_base_val = exp(log_ppl_base.first);
     const double ppl_base_unc = ppl_base_val * log_ppl_base.second; // ppl_base_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_base.second ** 2 )
-
+    LOG("Mean PPL(base) : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc);

     const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
-    //
+    // LOG("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov);
     const double log_ppl_cor = log_ppl_cov / (log_ppl.second*log_ppl_base.second);
-
+    LOG("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor);

     const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
     const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
-
+    LOG("Mean ln(PPL(Q)/PPL(base)) : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc);

     const double ppl_ratio_val = exp(log_ppl_ratio_val);
     const double ppl_ratio_unc = ppl_ratio_val * log_ppl_ratio_unc; // ppl_ratio_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_ratio.second ** 2 )
-
+    LOG("Mean PPL(Q)/PPL(base) : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc);

     const double ppl_cov = ppl_val * ppl_base_val * log_ppl_cov;
     const double ppl_diff_val = ppl_val - ppl_base_val;
     const double ppl_diff_unc = sqrt(ppl_unc*ppl_unc + ppl_base_unc*ppl_base_unc - 2.0*ppl_cov);
-
+    LOG("Mean PPL(Q)-PPL(base) : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc);

-
+    LOG("\n");

-
+    LOG("====== KL divergence statistics ======\n");
     auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
-
+    LOG("Mean KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second);
     auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
                                                : kld_values[kld_values.size()/2];

@@ -1915,67 +1890,68 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
         return (1 - p)*values[ip] + p*values[std::min(ip+1, values.size()-1)];
     };

-
-
-
-
-
-
-
-
-
+    LOG("Maximum KLD: %10.6f\n", kld_values.back());
+    LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
+    LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
+    LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
+    LOG("Median KLD: %10.6f\n", kld_median);
+    LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
+    LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));
+    LOG(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f));
+    LOG("Minimum KLD: %10.6f\n", kld_values.front());

-
+    LOG("\n");

-
+    LOG("====== Token probability statistics ======\n");

     auto p_diff = mean_and_uncertainty(kld.sum_p_diff, kld.sum_p_diff2, kld.count);
-
+    LOG("Mean Δp: %6.3lf ± %5.3lf %%\n", 100.0*p_diff.first, 100.0*p_diff.second);

     auto p_diff_median = p_diff_values.size()%2 == 0 ? 0.5f*(p_diff_values[p_diff_values.size()/2] + p_diff_values[p_diff_values.size()/2-1])
                                                      : p_diff_values[p_diff_values.size()/2];

-
-
-
-
-
-
-
-
-
-
-
-
-
+    LOG("Maximum Δp: %6.3lf%%\n", 100.0*p_diff_values.back());
+    LOG("99.9%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f));
+    LOG("99.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f));
+    LOG("95.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f));
+    LOG("90.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f));
+    LOG("75.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f));
+    LOG("Median Δp: %6.3lf%%\n", 100.0*p_diff_median);
+    LOG("25.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f));
+    LOG("10.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f));
+    LOG(" 5.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f));
+    LOG(" 1.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f));
+    LOG(" 0.1%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f));
+    LOG("Minimum Δp: %6.3lf%%\n", 100.0*p_diff_values.front());

     auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
-    //
+    // LOG("MSE Δp : %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second);

     const double p_diff_rms_val = sqrt(p_diff_mse.first);
     const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
-
+    LOG("RMS Δp : %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);

     const double same_top_p = 1.0*kld.n_same_top/kld.count;
-
-
+    LOG("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1)));
 }

 int main(int argc, char ** argv) {
-
+    common_params params;

     params.n_ctx = 512;
     params.logits_all = true;
+    params.escape = false;

-    if (!
-    gpt_params_print_usage(argc, argv, params);
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
         return 1;
     }

+    common_init();
+
     const int32_t n_ctx = params.n_ctx;

     if (n_ctx <= 0) {
-
+        LOG_ERR("%s: perplexity tool requires '--ctx-size' > 0\n", __func__);
         return 1;
     }

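The `main` changes above move the tool to the `common_params` flow: parse arguments for a specific example, call `common_init()`, then create the model and context in one step with `common_init_from_params`. A stripped-down sketch of that sequence, assuming the headers match this package's `common/` sources:

```cpp
// Sketch of the common_params initialization sequence shown in the diff;
// the skeleton main() is illustrative, not code from the package.
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

int main(int argc, char ** argv) {
    common_params params;
    params.n_ctx = 512;                         // defaults may be set before parsing

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
        return 1;                               // the parser reports usage/errors itself
    }

    common_init();                              // common logging / runtime setup

    llama_backend_init();
    llama_numa_init(params.numa);

    common_init_result llama_init = common_init_from_params(params);
    llama_model   * model = llama_init.model;
    llama_context * ctx   = llama_init.context;
    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    // ... run the evaluation ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```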
@@ -2000,45 +1976,35 @@ int main(int argc, char ** argv) {
     }

     if (params.ppl_stride > 0) {
-
+        LOG_INF("Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
                 params.n_ctx, params.n_ctx + params.ppl_stride/2);
         params.n_ctx += params.ppl_stride/2;
     }

-    print_build_info();
-
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
-
     llama_backend_init();
     llama_numa_init(params.numa);

-    llama_model * model;
-    llama_context * ctx;
-
     // load the model and apply lora adapter, if any
-
+    common_init_result llama_init = common_init_from_params(params);
+
+    llama_model * model = llama_init.model;
+    llama_context * ctx = llama_init.context;
     if (model == NULL) {
-
+        LOG_ERR("%s: unable to load model\n", __func__);
         return 1;
     }

     const int n_ctx_train = llama_n_ctx_train(model);

     if (params.n_ctx > n_ctx_train) {
-
+        LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, params.n_ctx);
     }

     // print system information
     {
-
-
+        LOG_INF("\n");
+        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }

     struct results_perplexity results;
@@ -2054,8 +2020,8 @@ int main(int argc, char ** argv) {
         results = perplexity(ctx, params, n_ctx);
     }

-
-
+    LOG("\n");
+    llama_perf_context_print(ctx);

     llama_free(ctx);
     llama_free_model(model);