@fugood/llama.node 0.3.2 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +2 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +8 -9
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +43 -9
- package/src/llama.cpp/.github/workflows/docker.yml +3 -0
- package/src/llama.cpp/CMakeLists.txt +7 -4
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +0 -2
- package/src/llama.cpp/common/arg.cpp +642 -607
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +79 -281
- package/src/llama.cpp/common/common.h +130 -100
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +116 -108
- package/src/llama.cpp/common/sampling.h +20 -20
- package/src/llama.cpp/docs/build.md +37 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +14 -14
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
- package/src/llama.cpp/examples/infill/infill.cpp +40 -86
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/clip.cpp +1 -0
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +37 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
- package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
- package/src/llama.cpp/examples/main/main.cpp +64 -109
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
- package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
- package/src/llama.cpp/examples/server/server.cpp +553 -691
- package/src/llama.cpp/examples/server/utils.hpp +312 -25
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +128 -96
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
- package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +53 -393
- package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
- package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
- package/src/llama.cpp/include/llama.h +67 -33
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-sampling.cpp +745 -105
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +49 -9
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +2636 -2406
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/tests/CMakeLists.txt +1 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
- package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +1 -0
- package/src/llama.cpp/tests/test-sampling.cpp +162 -137
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
- /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
|
@@ -1,1515 +0,0 @@
|
|
|
1
|
-
#include "train.h"
|
|
2
|
-
#include "common.h"
|
|
3
|
-
|
|
4
|
-
#include <algorithm>
|
|
5
|
-
#include <random>
|
|
6
|
-
#include <sstream>
|
|
7
|
-
#include <functional>
|
|
8
|
-
#include <cstring>
|
|
9
|
-
|
|
10
|
-
struct random_normal_distribution {
|
|
11
|
-
std::mt19937 gen;
|
|
12
|
-
std::normal_distribution<float> rd;
|
|
13
|
-
float min;
|
|
14
|
-
float max;
|
|
15
|
-
};
|
|
16
|
-
|
|
17
|
-
struct random_uniform_distribution {
|
|
18
|
-
std::mt19937 gen;
|
|
19
|
-
std::uniform_real_distribution<float> rd;
|
|
20
|
-
};
|
|
21
|
-
|
|
22
|
-
struct train_state * init_train_state() {
|
|
23
|
-
struct train_state * state = new struct train_state;
|
|
24
|
-
state->train_its = 0;
|
|
25
|
-
state->train_samples = 0;
|
|
26
|
-
state->train_tokens = 0;
|
|
27
|
-
state->train_epochs = 0;
|
|
28
|
-
state->shuffle_samples_hash = 0;
|
|
29
|
-
state->shuffle_sample_count = 0;
|
|
30
|
-
state->shuffle_next_sample = 0;
|
|
31
|
-
state->shuffle_rng_state_current = "";
|
|
32
|
-
state->shuffle_rng_state_next = "";
|
|
33
|
-
|
|
34
|
-
state->opt = new struct ggml_opt_context;
|
|
35
|
-
state->opt->ctx = NULL;
|
|
36
|
-
state->opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
|
|
37
|
-
state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
|
|
38
|
-
state->opt->loss_after = 0.0f;
|
|
39
|
-
|
|
40
|
-
return state;
|
|
41
|
-
}
|
|
42
|
-
|
|
43
|
-
void free_train_state(struct train_state * state) {
|
|
44
|
-
delete state->opt;
|
|
45
|
-
delete state;
|
|
46
|
-
}
|
|
47
|
-
|
|
48
|
-
struct random_normal_distribution * init_random_normal_distribution(
|
|
49
|
-
int seed, float mean, float std, float min, float max
|
|
50
|
-
) {
|
|
51
|
-
struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution));
|
|
52
|
-
rnd->gen = std::mt19937(seed);
|
|
53
|
-
rnd->rd = std::normal_distribution<float>{mean, std};
|
|
54
|
-
rnd->min = min;
|
|
55
|
-
rnd->max = max;
|
|
56
|
-
return rnd;
|
|
57
|
-
}
|
|
58
|
-
|
|
59
|
-
struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max) {
|
|
60
|
-
struct random_uniform_distribution * rnd = (struct random_uniform_distribution *) malloc(sizeof(struct random_uniform_distribution));
|
|
61
|
-
rnd->gen = std::mt19937(seed);
|
|
62
|
-
rnd->rd = std::uniform_real_distribution<float>{min, max};
|
|
63
|
-
return rnd;
|
|
64
|
-
}
|
|
65
|
-
|
|
66
|
-
void free_random_normal_distribution (struct random_normal_distribution * rnd) {
|
|
67
|
-
free(rnd);
|
|
68
|
-
}
|
|
69
|
-
|
|
70
|
-
void free_random_uniform_distribution(struct random_uniform_distribution * rnd) {
|
|
71
|
-
free(rnd);
|
|
72
|
-
}
|
|
73
|
-
|
|
74
|
-
struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
|
|
75
|
-
float scale = 1.0f; // xavier
|
|
76
|
-
switch (ggml_n_dims(tensor)) {
|
|
77
|
-
case 1:
|
|
78
|
-
scale /= sqrtf((float) tensor->ne[0]);
|
|
79
|
-
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
|
80
|
-
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
|
|
81
|
-
*dst = scale * frand_normal(rnd);
|
|
82
|
-
}
|
|
83
|
-
break;
|
|
84
|
-
case 2:
|
|
85
|
-
scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
|
|
86
|
-
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
|
87
|
-
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
|
88
|
-
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
|
|
89
|
-
*dst = scale * frand_normal(rnd);
|
|
90
|
-
}
|
|
91
|
-
}
|
|
92
|
-
break;
|
|
93
|
-
case 3:
|
|
94
|
-
scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
|
|
95
|
-
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
|
|
96
|
-
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
|
97
|
-
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
|
98
|
-
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
|
|
99
|
-
*dst = scale * frand_normal(rnd);
|
|
100
|
-
}
|
|
101
|
-
}
|
|
102
|
-
}
|
|
103
|
-
break;
|
|
104
|
-
case 4:
|
|
105
|
-
scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]);
|
|
106
|
-
for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
|
|
107
|
-
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
|
|
108
|
-
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
|
109
|
-
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
|
110
|
-
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
|
|
111
|
-
*dst = scale * frand_normal(rnd);
|
|
112
|
-
}
|
|
113
|
-
}
|
|
114
|
-
}
|
|
115
|
-
}
|
|
116
|
-
break;
|
|
117
|
-
default:
|
|
118
|
-
die("Unsupported tensor->n_dims");
|
|
119
|
-
};
|
|
120
|
-
return tensor;
|
|
121
|
-
}
|
|
122
|
-
|
|
123
|
-
struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
|
|
124
|
-
switch (ggml_n_dims(tensor)) {
|
|
125
|
-
case 1:
|
|
126
|
-
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
|
127
|
-
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]);
|
|
128
|
-
*dst = frand_uniform(rnd);
|
|
129
|
-
}
|
|
130
|
-
break;
|
|
131
|
-
case 2:
|
|
132
|
-
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
|
133
|
-
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
|
134
|
-
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
|
|
135
|
-
*dst = frand_uniform(rnd);
|
|
136
|
-
}
|
|
137
|
-
}
|
|
138
|
-
break;
|
|
139
|
-
case 3:
|
|
140
|
-
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
|
|
141
|
-
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
|
142
|
-
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
|
143
|
-
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]);
|
|
144
|
-
*dst = frand_uniform(rnd);
|
|
145
|
-
}
|
|
146
|
-
}
|
|
147
|
-
}
|
|
148
|
-
break;
|
|
149
|
-
case 4:
|
|
150
|
-
for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
|
|
151
|
-
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
|
|
152
|
-
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
|
153
|
-
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
|
154
|
-
float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]);
|
|
155
|
-
*dst = frand_uniform(rnd);
|
|
156
|
-
}
|
|
157
|
-
}
|
|
158
|
-
}
|
|
159
|
-
}
|
|
160
|
-
break;
|
|
161
|
-
default:
|
|
162
|
-
die("Unsupported tensor->n_dims");
|
|
163
|
-
};
|
|
164
|
-
return tensor;
|
|
165
|
-
}
|
|
166
|
-
|
|
167
|
-
float frand() {
|
|
168
|
-
return (float)rand()/((float)(RAND_MAX) + 1.0f);
|
|
169
|
-
}
|
|
170
|
-
|
|
171
|
-
float frand_normal(struct random_normal_distribution * rnd) {
|
|
172
|
-
return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max);
|
|
173
|
-
}
|
|
174
|
-
|
|
175
|
-
float frand_uniform(struct random_uniform_distribution * rnd) {
|
|
176
|
-
return rnd->rd(rnd->gen);
|
|
177
|
-
}
|
|
178
|
-
|
|
179
|
-
int clamp(const int v, const int min, const int max) {
|
|
180
|
-
return ((v < min) ? (min) : (v > max) ? (max) : v);
|
|
181
|
-
}
|
|
182
|
-
|
|
183
|
-
float fclamp(const float v, const float min, const float max) {
|
|
184
|
-
return ((v < min) ? (min) : (v > max) ? (max) : v);
|
|
185
|
-
}
|
|
186
|
-
|
|
187
|
-
void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
|
|
188
|
-
GGML_ASSERT(tensor->ne[0] == ne0);
|
|
189
|
-
GGML_ASSERT(tensor->ne[1] == 1);
|
|
190
|
-
GGML_ASSERT(tensor->ne[2] == 1);
|
|
191
|
-
GGML_ASSERT(tensor->ne[3] == 1);
|
|
192
|
-
}
|
|
193
|
-
|
|
194
|
-
void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
|
|
195
|
-
GGML_ASSERT(tensor->ne[0] == ne0);
|
|
196
|
-
GGML_ASSERT(tensor->ne[1] == ne1);
|
|
197
|
-
GGML_ASSERT(tensor->ne[2] == 1);
|
|
198
|
-
GGML_ASSERT(tensor->ne[3] == 1);
|
|
199
|
-
}
|
|
200
|
-
|
|
201
|
-
void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
|
|
202
|
-
GGML_ASSERT(tensor->ne[0] == ne0);
|
|
203
|
-
GGML_ASSERT(tensor->ne[1] == ne1);
|
|
204
|
-
GGML_ASSERT(tensor->ne[2] == ne2);
|
|
205
|
-
GGML_ASSERT(tensor->ne[3] == 1);
|
|
206
|
-
}
|
|
207
|
-
|
|
208
|
-
void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
|
|
209
|
-
GGML_ASSERT(tensor->ne[0] == ne0);
|
|
210
|
-
GGML_ASSERT(tensor->ne[1] == ne1);
|
|
211
|
-
GGML_ASSERT(tensor->ne[2] == ne2);
|
|
212
|
-
GGML_ASSERT(tensor->ne[3] == ne3);
|
|
213
|
-
}
|
|
214
|
-
|
|
215
|
-
int64_t get_example_targets_batch(
|
|
216
|
-
struct llama_context * lctx,
|
|
217
|
-
struct ggml_tensor * tokens_input,
|
|
218
|
-
struct ggml_tensor * target_probs,
|
|
219
|
-
int64_t example_id,
|
|
220
|
-
const size_t * samples_offs,
|
|
221
|
-
const size_t * samples_begin,
|
|
222
|
-
const size_t * samples_size,
|
|
223
|
-
size_t samples_count,
|
|
224
|
-
const llama_token * train_data,
|
|
225
|
-
size_t n_train_data,
|
|
226
|
-
bool separate_with_eos,
|
|
227
|
-
bool separate_with_bos,
|
|
228
|
-
bool fill_with_next_samples,
|
|
229
|
-
bool sample_random_offsets
|
|
230
|
-
) {
|
|
231
|
-
GGML_ASSERT(samples_count > 0);
|
|
232
|
-
GGML_ASSERT(ggml_is_matrix(tokens_input));
|
|
233
|
-
GGML_ASSERT(ggml_is_3d(target_probs));
|
|
234
|
-
int64_t n_vocab = target_probs->ne[0];
|
|
235
|
-
int64_t n_tokens = tokens_input->ne[0];
|
|
236
|
-
int64_t n_batch = tokens_input->ne[1];
|
|
237
|
-
GGML_ASSERT(n_vocab == target_probs->ne[0]);
|
|
238
|
-
GGML_ASSERT(n_tokens == target_probs->ne[1]);
|
|
239
|
-
GGML_ASSERT(n_batch == target_probs->ne[2]);
|
|
240
|
-
|
|
241
|
-
int64_t used_samples = 0;
|
|
242
|
-
|
|
243
|
-
ggml_set_f32(target_probs, 0.0f);
|
|
244
|
-
llama_token bos = llama_token_bos(llama_get_model(lctx));
|
|
245
|
-
llama_token eos = llama_token_eos(llama_get_model(lctx));
|
|
246
|
-
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
|
|
247
|
-
for (int k=0; k<n_batch; ++k) {
|
|
248
|
-
// printf("%s: batch %d\n", __func__, k);
|
|
249
|
-
size_t sample_idx = (example_id + used_samples) % samples_count;
|
|
250
|
-
size_t sample_offs = sample_random_offsets ? samples_offs[sample_idx] : 0;
|
|
251
|
-
size_t sample_begin = samples_begin[sample_idx];
|
|
252
|
-
size_t sample_size = samples_size[sample_idx];
|
|
253
|
-
++used_samples;
|
|
254
|
-
|
|
255
|
-
// printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
|
|
256
|
-
GGML_ASSERT(sample_begin+sample_size-1 < n_train_data);
|
|
257
|
-
|
|
258
|
-
ggml_set_i32_nd(tokens_input, 0, k, 0, 0, bos);
|
|
259
|
-
bool sample_separation_eos = !separate_with_eos;
|
|
260
|
-
bool sample_separation_bos = !separate_with_bos;
|
|
261
|
-
for (int64_t i=0; i<n_tokens; ++i) {
|
|
262
|
-
llama_token token = eos;
|
|
263
|
-
if (sample_offs >= sample_size && fill_with_next_samples) {
|
|
264
|
-
if (!sample_separation_eos) {
|
|
265
|
-
// insert eos token to separate samples
|
|
266
|
-
sample_separation_eos = true;
|
|
267
|
-
} else if (!sample_separation_bos) {
|
|
268
|
-
// insert bos token to separate samples
|
|
269
|
-
sample_separation_bos = true;
|
|
270
|
-
token = bos;
|
|
271
|
-
} else {
|
|
272
|
-
// sample separation is done, continue with next sample
|
|
273
|
-
sample_separation_eos = !separate_with_eos;
|
|
274
|
-
sample_separation_bos = !separate_with_bos;
|
|
275
|
-
sample_offs = 0;
|
|
276
|
-
sample_idx = (example_id + used_samples) % samples_count;
|
|
277
|
-
sample_begin = samples_begin[sample_idx];
|
|
278
|
-
sample_size = samples_size[sample_idx];
|
|
279
|
-
++used_samples;
|
|
280
|
-
}
|
|
281
|
-
}
|
|
282
|
-
// note: no else-if here
|
|
283
|
-
if (sample_offs < sample_size) {
|
|
284
|
-
token = clamp(train_data[sample_begin+sample_offs], 0, (llama_token) (n_vocab - 1));
|
|
285
|
-
++sample_offs;
|
|
286
|
-
}
|
|
287
|
-
ggml_set_f32_nd(target_probs, token, (int) i, (int) k, 0, +1.0f);
|
|
288
|
-
if (i+1<n_tokens) {
|
|
289
|
-
ggml_set_i32_nd(tokens_input, (int) (i + 1), (int) k, 0, 0, token);
|
|
290
|
-
}
|
|
291
|
-
}
|
|
292
|
-
}
|
|
293
|
-
|
|
294
|
-
return used_samples;
|
|
295
|
-
}
|
|
296
|
-
|
|
297
|
-
void mt19937_set_state(std::mt19937& rng, const std::string& rng_state) {
|
|
298
|
-
std::stringstream s_rng_state;
|
|
299
|
-
s_rng_state.imbue(std::locale::classic());
|
|
300
|
-
s_rng_state.exceptions(std::stringstream::failbit);
|
|
301
|
-
s_rng_state.str(rng_state);
|
|
302
|
-
s_rng_state >> rng;
|
|
303
|
-
}
|
|
304
|
-
|
|
305
|
-
std::string mt19937_get_state(const std::mt19937& rng) {
|
|
306
|
-
std::stringstream s_rng_state;
|
|
307
|
-
s_rng_state.imbue(std::locale::classic());
|
|
308
|
-
s_rng_state << rng;
|
|
309
|
-
return s_rng_state.str();
|
|
310
|
-
}
|
|
311
|
-
|
|
312
|
-
std::string mt19937_seed_to_state(unsigned seed) {
|
|
313
|
-
std::mt19937 rng(seed);
|
|
314
|
-
return mt19937_get_state(rng);
|
|
315
|
-
}
|
|
316
|
-
|
|
317
|
-
std::string shuffle_samples(
|
|
318
|
-
const std::string & rng_state,
|
|
319
|
-
size_t * shuffled_offs,
|
|
320
|
-
size_t * shuffled_begins,
|
|
321
|
-
size_t * shuffled_sizes,
|
|
322
|
-
const size_t * begins,
|
|
323
|
-
const size_t * sizes,
|
|
324
|
-
size_t count) {
|
|
325
|
-
if (count == 0) return rng_state;
|
|
326
|
-
|
|
327
|
-
std::mt19937 rng;
|
|
328
|
-
mt19937_set_state(rng, rng_state);
|
|
329
|
-
|
|
330
|
-
// sort indices by random value for each index
|
|
331
|
-
std::vector<size_t> idcs;
|
|
332
|
-
{
|
|
333
|
-
std::vector<unsigned> rnd;
|
|
334
|
-
idcs.resize(count);
|
|
335
|
-
rnd.resize(count);
|
|
336
|
-
for (unsigned i=0; i<count; ++i) {
|
|
337
|
-
idcs[i] = i;
|
|
338
|
-
rnd[i] = rng();
|
|
339
|
-
}
|
|
340
|
-
|
|
341
|
-
std::sort(idcs.begin(), idcs.end(), [&rnd](size_t a, size_t b){
|
|
342
|
-
// stable sort for reproducibility
|
|
343
|
-
return (rnd[a] == rnd[b]) ? (a < b) : (rnd[a] < rnd[b]);
|
|
344
|
-
});
|
|
345
|
-
}
|
|
346
|
-
|
|
347
|
-
// create random offsets
|
|
348
|
-
for (unsigned i=0; i<count; ++i) {
|
|
349
|
-
shuffled_offs[i] = (size_t) ((sizes[idcs[i]] - 1) * ((double) rng() / (double) (rng.max()-1)));
|
|
350
|
-
}
|
|
351
|
-
|
|
352
|
-
// reorder begins and sizes by sorted indices
|
|
353
|
-
for (unsigned i=0; i<count; ++i) {
|
|
354
|
-
shuffled_begins[i] = begins[idcs[i]];
|
|
355
|
-
}
|
|
356
|
-
|
|
357
|
-
for (unsigned i=0; i<count; ++i) {
|
|
358
|
-
shuffled_sizes[i] = sizes[idcs[i]];
|
|
359
|
-
}
|
|
360
|
-
|
|
361
|
-
return mt19937_get_state(rng);
|
|
362
|
-
}
|
|
363
|
-
|
|
364
|
-
size_t hash_combine(size_t h1, size_t h2) {
|
|
365
|
-
return h1 ^ (h2 << 1);
|
|
366
|
-
}
|
|
367
|
-
|
|
368
|
-
size_t compute_samples_hash(const char* fn, const size_t* samples_begin, const size_t* samples_size, size_t sample_count) {
|
|
369
|
-
std::hash<std::string> h_string;
|
|
370
|
-
std::hash<unsigned long long> h_ull;
|
|
371
|
-
size_t h = h_string(std::string(fn));
|
|
372
|
-
h = hash_combine(h, h_ull((unsigned long long) sample_count));
|
|
373
|
-
for (size_t i=0; i< sample_count; ++i) {
|
|
374
|
-
h = hash_combine(h, h_ull((unsigned long long) samples_begin[i]));
|
|
375
|
-
h = hash_combine(h, h_ull((unsigned long long) samples_size[i]));
|
|
376
|
-
}
|
|
377
|
-
return h;
|
|
378
|
-
}
|
|
379
|
-
|
|
380
|
-
std::string replace_str(const char * s, const char * needle, const char * replacement) {
|
|
381
|
-
std::string str = s;
|
|
382
|
-
size_t pos = str.find(needle);
|
|
383
|
-
if (pos != std::string::npos) {
|
|
384
|
-
str.replace(pos, strlen(needle), replacement);
|
|
385
|
-
}
|
|
386
|
-
return str;
|
|
387
|
-
}
|
|
388
|
-
|
|
389
|
-
void print_duration(double fmillis) {
|
|
390
|
-
if (fmillis < 1000.0f) {
|
|
391
|
-
printf("%.1fms", (float) fmillis);
|
|
392
|
-
return;
|
|
393
|
-
}
|
|
394
|
-
const int64_t one_sec = 1000;
|
|
395
|
-
const int64_t one_min = one_sec * 60;
|
|
396
|
-
const int64_t one_hour = one_min * 60;
|
|
397
|
-
const int64_t one_day = one_hour * 24;
|
|
398
|
-
|
|
399
|
-
int64_t millis = (int64_t) fmillis;
|
|
400
|
-
int64_t days = millis/one_day;
|
|
401
|
-
int64_t hours = (millis - days*one_day)/one_hour;
|
|
402
|
-
int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min;
|
|
403
|
-
int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec;
|
|
404
|
-
|
|
405
|
-
// to print int64_t either cast to (long long int) or use macro PRId64 from <inttypes.h>
|
|
406
|
-
if (days > 0) {
|
|
407
|
-
printf("%lldd ", (long long int) days);
|
|
408
|
-
}
|
|
409
|
-
printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds);
|
|
410
|
-
}
|
|
411
|
-
|
|
412
|
-
float cosine_decay(int64_t step, int64_t decay_steps, float minimum) {
|
|
413
|
-
if (step > decay_steps) {
|
|
414
|
-
step = decay_steps;
|
|
415
|
-
}
|
|
416
|
-
const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps));
|
|
417
|
-
const float decay = (1 - minimum)*cosine_decay + minimum;
|
|
418
|
-
return decay;
|
|
419
|
-
}
|
|
420
|
-
|
|
421
|
-
float cosine_decay_restart(int64_t step, int64_t decay_steps, float minimum, float restart_step_mult) {
|
|
422
|
-
while (step > decay_steps) {
|
|
423
|
-
step -= decay_steps;
|
|
424
|
-
decay_steps = (int64_t) (restart_step_mult * decay_steps);
|
|
425
|
-
}
|
|
426
|
-
return cosine_decay(step, decay_steps, minimum);
|
|
427
|
-
}
|
|
428
|
-
|
|
429
|
-
float learning_schedule(
|
|
430
|
-
int64_t step,
|
|
431
|
-
int64_t warmup_steps,
|
|
432
|
-
int64_t cos_decay_steps,
|
|
433
|
-
float learning_rate,
|
|
434
|
-
float overall_minimum,
|
|
435
|
-
float cos_decay_minimum,
|
|
436
|
-
float cos_decay_restart_step_mult,
|
|
437
|
-
bool enable_restart) {
|
|
438
|
-
|
|
439
|
-
float result =
|
|
440
|
-
(step < warmup_steps)
|
|
441
|
-
? (float) step / (float) warmup_steps
|
|
442
|
-
: enable_restart
|
|
443
|
-
? cosine_decay_restart(
|
|
444
|
-
step - warmup_steps,
|
|
445
|
-
cos_decay_steps,
|
|
446
|
-
cos_decay_minimum,
|
|
447
|
-
cos_decay_restart_step_mult)
|
|
448
|
-
: cosine_decay(
|
|
449
|
-
step,
|
|
450
|
-
cos_decay_steps,
|
|
451
|
-
cos_decay_minimum);
|
|
452
|
-
|
|
453
|
-
float min = overall_minimum / learning_rate;
|
|
454
|
-
result = min + result * (1.0f - min);
|
|
455
|
-
return result;
|
|
456
|
-
}
|
|
457
|
-
|
|
458
|
-
static bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) {
|
|
459
|
-
GGML_ASSERT(a != NULL);
|
|
460
|
-
GGML_ASSERT(b != NULL);
|
|
461
|
-
GGML_ASSERT(a->type == b->type);
|
|
462
|
-
GGML_ASSERT(ggml_are_same_shape(a, b));
|
|
463
|
-
GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b));
|
|
464
|
-
|
|
465
|
-
return true;
|
|
466
|
-
}
|
|
467
|
-
|
|
468
|
-
void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) {
|
|
469
|
-
if (dst == NULL) {
|
|
470
|
-
return;
|
|
471
|
-
}
|
|
472
|
-
struct ggml_tensor * t = ggml_get_tensor(ctx, name);
|
|
473
|
-
GGML_ASSERT(are_same_layout(dst, t));
|
|
474
|
-
memcpy(dst->data, t->data, ggml_nbytes(t));
|
|
475
|
-
|
|
476
|
-
if (strlen(ggml_get_name(dst)) == 0) {
|
|
477
|
-
ggml_set_name(dst, name);
|
|
478
|
-
}
|
|
479
|
-
}
|
|
480
|
-
|
|
481
|
-
// gguf constants
|
|
482
|
-
static const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type";
|
|
483
|
-
static const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam";
|
|
484
|
-
static const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs";
|
|
485
|
-
static const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version";
|
|
486
|
-
static const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count";
|
|
487
|
-
static const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count";
|
|
488
|
-
static const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count";
|
|
489
|
-
static const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized";
|
|
490
|
-
static const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss";
|
|
491
|
-
static const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss";
|
|
492
|
-
static const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count";
|
|
493
|
-
static const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count";
|
|
494
|
-
static const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss";
|
|
495
|
-
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step";
|
|
496
|
-
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j";
|
|
497
|
-
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k";
|
|
498
|
-
static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end";
|
|
499
|
-
static const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count";
|
|
500
|
-
|
|
501
|
-
static const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments";
|
|
502
|
-
static const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments";
|
|
503
|
-
static const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values";
|
|
504
|
-
|
|
505
|
-
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters";
|
|
506
|
-
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters";
|
|
507
|
-
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients";
|
|
508
|
-
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients";
|
|
509
|
-
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction";
|
|
510
|
-
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values";
|
|
511
|
-
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha";
|
|
512
|
-
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys";
|
|
513
|
-
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
|
|
514
|
-
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
|
|
515
|
-
|
|
516
|
-
static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
|
|
517
|
-
static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
|
|
518
|
-
static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
|
|
519
|
-
static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
|
|
520
|
-
static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count";
|
|
521
|
-
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
|
|
522
|
-
static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state";
|
|
523
|
-
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
|
|
524
|
-
static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample";
|
|
525
|
-
|
|
526
|
-
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
|
|
527
|
-
{ \
|
|
528
|
-
const std::string skey(key); \
|
|
529
|
-
const int kid = gguf_find_key(ctx, skey.c_str()); \
|
|
530
|
-
if (kid >= 0) { \
|
|
531
|
-
enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
|
|
532
|
-
if (ktype != (type)) { \
|
|
533
|
-
die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
|
|
534
|
-
} \
|
|
535
|
-
(dst) = func(ctx, kid); \
|
|
536
|
-
} else if (req) { \
|
|
537
|
-
die_fmt("key not found in model: %s", skey.c_str()); \
|
|
538
|
-
} \
|
|
539
|
-
}
|
|
540
|
-
|
|
541
|
-
void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
|
|
542
|
-
// NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
|
|
543
|
-
|
|
544
|
-
uint32_t file_version;
|
|
545
|
-
GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION);
|
|
546
|
-
GGML_ASSERT(file_version == 0);
|
|
547
|
-
|
|
548
|
-
GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
|
|
549
|
-
GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
|
|
550
|
-
GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
|
|
551
|
-
|
|
552
|
-
uint64_t nx;
|
|
553
|
-
GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT);
|
|
554
|
-
opt->nx = (size_t) nx;
|
|
555
|
-
|
|
556
|
-
// don't call ggml_opt_init until optimizer type and optimizer specific parameters are know
|
|
557
|
-
|
|
558
|
-
std::string opt_type;
|
|
559
|
-
GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE);
|
|
560
|
-
if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) {
|
|
561
|
-
opt->params.type = GGML_OPT_TYPE_ADAM;
|
|
562
|
-
|
|
563
|
-
GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS);
|
|
564
|
-
GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS);
|
|
565
|
-
GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT);
|
|
566
|
-
|
|
567
|
-
ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
|
|
568
|
-
|
|
569
|
-
copy_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
|
|
570
|
-
copy_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
|
|
571
|
-
copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
|
|
572
|
-
} else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) {
|
|
573
|
-
opt->params.type = GGML_OPT_TYPE_LBFGS;
|
|
574
|
-
|
|
575
|
-
GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT);
|
|
576
|
-
GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS);
|
|
577
|
-
GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP);
|
|
578
|
-
GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J);
|
|
579
|
-
GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K);
|
|
580
|
-
GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END);
|
|
581
|
-
GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT);
|
|
582
|
-
|
|
583
|
-
ggml_opt_init(opt->ctx, opt, opt->params, opt->nx);
|
|
584
|
-
|
|
585
|
-
copy_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
|
|
586
|
-
copy_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
|
|
587
|
-
copy_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
|
|
588
|
-
copy_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
|
|
589
|
-
copy_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
|
|
590
|
-
copy_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
|
|
591
|
-
copy_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
|
|
592
|
-
copy_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
|
|
593
|
-
copy_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
|
|
594
|
-
copy_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
|
|
595
|
-
} else {
|
|
596
|
-
die("unknown optimizer type\n");
|
|
597
|
-
}
|
|
598
|
-
}
|
|
599
|
-
|
|
600
|
-
void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) {
|
|
601
|
-
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0);
|
|
602
|
-
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past);
|
|
603
|
-
gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx);
|
|
604
|
-
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
|
|
605
|
-
gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
|
|
606
|
-
|
|
607
|
-
switch (opt->params.type) {
|
|
608
|
-
case GGML_OPT_TYPE_ADAM:
|
|
609
|
-
{
|
|
610
|
-
gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM);
|
|
611
|
-
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best);
|
|
612
|
-
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev);
|
|
613
|
-
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement);
|
|
614
|
-
|
|
615
|
-
ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS);
|
|
616
|
-
ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS);
|
|
617
|
-
if (opt->adam.pf) {
|
|
618
|
-
ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES);
|
|
619
|
-
}
|
|
620
|
-
|
|
621
|
-
gguf_add_tensor(fctx, opt->adam.m);
|
|
622
|
-
gguf_add_tensor(fctx, opt->adam.v);
|
|
623
|
-
if (opt->adam.pf) {
|
|
624
|
-
gguf_add_tensor(fctx, opt->adam.pf);
|
|
625
|
-
}
|
|
626
|
-
} break;
|
|
627
|
-
case GGML_OPT_TYPE_LBFGS:
|
|
628
|
-
{
|
|
629
|
-
gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS);
|
|
630
|
-
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m);
|
|
631
|
-
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best);
|
|
632
|
-
gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step);
|
|
633
|
-
gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j);
|
|
634
|
-
gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k);
|
|
635
|
-
gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end);
|
|
636
|
-
gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement);
|
|
637
|
-
|
|
638
|
-
ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS);
|
|
639
|
-
ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS);
|
|
640
|
-
ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS);
|
|
641
|
-
ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS);
|
|
642
|
-
ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION);
|
|
643
|
-
if (opt->lbfgs.pf) {
|
|
644
|
-
ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES);
|
|
645
|
-
}
|
|
646
|
-
ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA);
|
|
647
|
-
ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS);
|
|
648
|
-
ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S);
|
|
649
|
-
ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y);
|
|
650
|
-
|
|
651
|
-
gguf_add_tensor(fctx, opt->lbfgs.x);
|
|
652
|
-
gguf_add_tensor(fctx, opt->lbfgs.xp);
|
|
653
|
-
gguf_add_tensor(fctx, opt->lbfgs.g);
|
|
654
|
-
gguf_add_tensor(fctx, opt->lbfgs.gp);
|
|
655
|
-
gguf_add_tensor(fctx, opt->lbfgs.d);
|
|
656
|
-
if (opt->lbfgs.pf) {
|
|
657
|
-
gguf_add_tensor(fctx, opt->lbfgs.pf);
|
|
658
|
-
}
|
|
659
|
-
gguf_add_tensor(fctx, opt->lbfgs.lmal);
|
|
660
|
-
gguf_add_tensor(fctx, opt->lbfgs.lmys);
|
|
661
|
-
gguf_add_tensor(fctx, opt->lbfgs.lms);
|
|
662
|
-
gguf_add_tensor(fctx, opt->lbfgs.lmy);
|
|
663
|
-
} break;
|
|
664
|
-
}
|
|
665
|
-
}
|
|
666
|
-
|
|
667
|
-
bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) {
|
|
668
|
-
if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) < 0) {
|
|
669
|
-
return false;
|
|
670
|
-
}
|
|
671
|
-
|
|
672
|
-
uint32_t file_version;
|
|
673
|
-
GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
|
|
674
|
-
GGML_ASSERT(file_version <= 1);
|
|
675
|
-
|
|
676
|
-
if (file_version == 0) {
|
|
677
|
-
|
|
678
|
-
GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
|
|
679
|
-
GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
|
|
680
|
-
GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
|
|
681
|
-
|
|
682
|
-
} else if (file_version == 1) {
|
|
683
|
-
|
|
684
|
-
GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT);
|
|
685
|
-
GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT);
|
|
686
|
-
GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT);
|
|
687
|
-
GGUF_GET_KEY(fctx, train->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT);
|
|
688
|
-
|
|
689
|
-
GGUF_GET_KEY(fctx, train->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH);
|
|
690
|
-
GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE);
|
|
691
|
-
GGUF_GET_KEY(fctx, train->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT);
|
|
692
|
-
GGUF_GET_KEY(fctx, train->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE);
|
|
693
|
-
}
|
|
694
|
-
|
|
695
|
-
load_opt_context_gguf(fctx, f_ggml_ctx, train->opt);
|
|
696
|
-
return true;
|
|
697
|
-
}
|
|
698
|
-
|
|
699
|
-
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
|
|
700
|
-
gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
|
|
701
|
-
gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
|
|
702
|
-
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, train->train_samples);
|
|
703
|
-
gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, train->train_tokens);
|
|
704
|
-
gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, train->train_epochs);
|
|
705
|
-
|
|
706
|
-
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash);
|
|
707
|
-
gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, train->shuffle_rng_state_current.c_str());
|
|
708
|
-
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count);
|
|
709
|
-
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) train->shuffle_next_sample);
|
|
710
|
-
|
|
711
|
-
save_opt_context_gguf(fctx, train->opt);
|
|
712
|
-
}
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
struct llama_file {
|
|
716
|
-
// use FILE * so we don't have to re-open the file to mmap
|
|
717
|
-
FILE * fp;
|
|
718
|
-
size_t size;
|
|
719
|
-
|
|
720
|
-
llama_file(const char * fname, const char * mode) {
|
|
721
|
-
fp = std::fopen(fname, mode);
|
|
722
|
-
if (fp == NULL) {
|
|
723
|
-
size = 0;
|
|
724
|
-
} else {
|
|
725
|
-
seek(0, SEEK_END);
|
|
726
|
-
size = tell();
|
|
727
|
-
seek(0, SEEK_SET);
|
|
728
|
-
}
|
|
729
|
-
}
|
|
730
|
-
|
|
731
|
-
size_t tell() const {
|
|
732
|
-
#ifdef _WIN32
|
|
733
|
-
__int64 ret = _ftelli64(fp);
|
|
734
|
-
#else
|
|
735
|
-
long ret = std::ftell(fp);
|
|
736
|
-
#endif
|
|
737
|
-
GGML_ASSERT(ret != -1); // this really shouldn't fail
|
|
738
|
-
return (size_t) ret;
|
|
739
|
-
}
|
|
740
|
-
|
|
741
|
-
void seek(size_t offset, int whence) {
|
|
742
|
-
#ifdef _WIN32
|
|
743
|
-
int ret = _fseeki64(fp, (__int64) offset, whence);
|
|
744
|
-
#else
|
|
745
|
-
int ret = std::fseek(fp, (long) offset, whence);
|
|
746
|
-
#endif
|
|
747
|
-
GGML_ASSERT(ret == 0); // same
|
|
748
|
-
}
|
|
749
|
-
|
|
750
|
-
void read_raw(void * ptr, size_t size) {
|
|
751
|
-
if (size == 0) {
|
|
752
|
-
return;
|
|
753
|
-
}
|
|
754
|
-
errno = 0;
|
|
755
|
-
std::size_t ret = std::fread(ptr, size, 1, fp);
|
|
756
|
-
if (ferror(fp)) {
|
|
757
|
-
die_fmt("read error: %s", strerror(errno));
|
|
758
|
-
}
|
|
759
|
-
if (ret != 1) {
|
|
760
|
-
die("unexpectedly reached end of file");
|
|
761
|
-
}
|
|
762
|
-
}
|
|
763
|
-
|
|
764
|
-
std::uint32_t read_u32() {
|
|
765
|
-
std::uint32_t ret;
|
|
766
|
-
read_raw(&ret, sizeof(ret));
|
|
767
|
-
return ret;
|
|
768
|
-
}
|
|
769
|
-
|
|
770
|
-
std::string read_string(std::uint32_t len) {
|
|
771
|
-
std::vector<char> chars(len);
|
|
772
|
-
read_raw(chars.data(), len);
|
|
773
|
-
return std::string(chars.data(), len);
|
|
774
|
-
}
|
|
775
|
-
|
|
776
|
-
void write_raw(const void * ptr, size_t size) {
|
|
777
|
-
if (size == 0) {
|
|
778
|
-
return;
|
|
779
|
-
}
|
|
780
|
-
errno = 0;
|
|
781
|
-
size_t ret = std::fwrite(ptr, size, 1, fp);
|
|
782
|
-
if (ret != 1) {
|
|
783
|
-
die_fmt("write error: %s", strerror(errno));
|
|
784
|
-
}
|
|
785
|
-
}
|
|
786
|
-
|
|
787
|
-
void write_u32(std::uint32_t val) {
|
|
788
|
-
write_raw(&val, sizeof(val));
|
|
789
|
-
}
|
|
790
|
-
|
|
791
|
-
~llama_file() {
|
|
792
|
-
if (fp) {
|
|
793
|
-
std::fclose(fp);
|
|
794
|
-
}
|
|
795
|
-
}
|
|
796
|
-
};
|
|
797
|
-
|
|
798
|
-
static size_t utf8_len(char src) {
|
|
799
|
-
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
|
800
|
-
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
|
801
|
-
return lookup[highbits];
|
|
802
|
-
}
|
|
803
|
-
|
|
804
|
-
// mark each byte with its utf8 unit number.
|
|
805
|
-
// returns the number of utf8 characters.
|
|
806
|
-
// e.g. when bytes == '\x61\xD0\xB0\x62',
|
|
807
|
-
// then utf8_units will become [0,0,1,0]
|
|
808
|
-
// utf8_nunits will become [1,2,2,1] and 3 is returned.
|
|
809
|
-
// bytes where utf8_units is zero, are the begin of an utf8 character.
|
|
810
|
-
static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) {
|
|
811
|
-
size_t offs = 0;
|
|
812
|
-
size_t count_utf8 = 0;
|
|
813
|
-
while(offs < count) {
|
|
814
|
-
int len = (int) utf8_len(bytes[offs]);
|
|
815
|
-
for (int i=0; i<len; ++i) {
|
|
816
|
-
utf8_units[offs+i] = i;
|
|
817
|
-
utf8_nunits[offs+i] = len;
|
|
818
|
-
}
|
|
819
|
-
offs += len;
|
|
820
|
-
++count_utf8;
|
|
821
|
-
}
|
|
822
|
-
return count_utf8;
|
|
823
|
-
}
|
|
824
|
-
|
|
825
|
-
size_t tokenize_file(
|
|
826
|
-
struct llama_context * lctx,
|
|
827
|
-
const char * filename,
|
|
828
|
-
const std::string & sample_start,
|
|
829
|
-
bool include_sample_start,
|
|
830
|
-
bool overlapping_samples,
|
|
831
|
-
unsigned context_length,
|
|
832
|
-
std::vector<llama_token> & out_tokens,
|
|
833
|
-
std::vector<size_t> & out_samples_begin,
|
|
834
|
-
std::vector<size_t> & out_samples_size) {
|
|
835
|
-
struct llama_file f(filename, "rb");
|
|
836
|
-
|
|
837
|
-
if (f.size == 0) {
|
|
838
|
-
out_tokens.clear();
|
|
839
|
-
out_samples_begin.clear();
|
|
840
|
-
out_samples_size.clear();
|
|
841
|
-
printf("%s: warning: empty or not existing training data file '%s'\n",
|
|
842
|
-
__func__, filename);
|
|
843
|
-
return out_tokens.size();
|
|
844
|
-
}
|
|
845
|
-
|
|
846
|
-
// account for possible leading whitespace that will be added by tokenizer
|
|
847
|
-
// e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12]
|
|
848
|
-
const int n_max_tokens_overhead = 1;
|
|
849
|
-
|
|
850
|
-
std::vector<char> buf;
|
|
851
|
-
buf.resize(f.size);
|
|
852
|
-
|
|
853
|
-
f.read_raw(buf.data(), f.size);
|
|
854
|
-
|
|
855
|
-
std::vector<int> utf8_units;
|
|
856
|
-
std::vector<int> utf8_nunits;
|
|
857
|
-
utf8_units.resize(buf.size());
|
|
858
|
-
utf8_nunits.resize(buf.size());
|
|
859
|
-
mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size());
|
|
860
|
-
|
|
861
|
-
if (sample_start.size() == 0) {
|
|
862
|
-
// tokenize all data at once
|
|
863
|
-
out_tokens.resize(buf.size() + n_max_tokens_overhead);
|
|
864
|
-
|
|
865
|
-
int n_tokens = llama_tokenize(
|
|
866
|
-
llama_get_model(lctx),
|
|
867
|
-
buf.data(),
|
|
868
|
-
(int) buf.size(),
|
|
869
|
-
out_tokens.data(),
|
|
870
|
-
(int) out_tokens.size(),
|
|
871
|
-
false, false);
|
|
872
|
-
if (n_tokens < 0) {
|
|
873
|
-
out_tokens.resize(-n_tokens);
|
|
874
|
-
n_tokens = llama_tokenize(
|
|
875
|
-
llama_get_model(lctx),
|
|
876
|
-
buf.data(),
|
|
877
|
-
(int) buf.size(),
|
|
878
|
-
out_tokens.data(),
|
|
879
|
-
(int) out_tokens.size(),
|
|
880
|
-
false, false);
|
|
881
|
-
}
|
|
882
|
-
if (n_tokens >= 0) {
|
|
883
|
-
out_tokens.resize(n_tokens);
|
|
884
|
-
}
|
|
885
|
-
|
|
886
|
-
// generate sample starts at all token positions
|
|
887
|
-
-        out_samples_begin.clear();
-        out_samples_begin.push_back(0);
-        out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size()));
-        size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0;
-        for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) {
-            out_samples_begin.push_back(sample_begin);
-            out_samples_size.push_back(context_length);
-        }
-    } else {
-        // split data into samples and tokenize each sample
-        std::string data_str(buf.data(), buf.size());
-        out_samples_begin.clear();
-        out_samples_size.clear();
-        out_tokens.clear();
-
-        // find all positions of pattern sample_start
-        size_t sample_begin = data_str.find(sample_start, 0);
-        while (sample_begin != std::string::npos) {
-            out_samples_begin.push_back(sample_begin);
-            const size_t search_start = sample_begin + sample_start.size();
-            sample_begin = data_str.find(sample_start, search_start);
-        }
-        if (out_samples_begin.size() == 0) {
-            printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n",
-                __func__, sample_start.c_str());
-            out_samples_begin.push_back(0);
-        }
-
-        out_samples_size.resize(out_samples_begin.size(), 0);
-
-        std::vector<char> buf_sample;
-        std::vector<llama_token> tok_sample;
-
-        const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size());
-        size_t found_too_big_sample   = 0;
-        size_t found_too_small_sample = 0;
-        size_t found_empty_sample     = 0;
-        size_t found_min_sample_size  = SIZE_MAX;
-        size_t found_max_sample_size  = 0;
-
-        size_t max_token_text_size = 0;
-        int n_vocab = llama_n_vocab(llama_get_model(lctx));
-        for (llama_token token=0; token < n_vocab; ++token) {
-            max_token_text_size = std::max(
-                max_token_text_size,
-                strlen(llama_token_get_text(llama_get_model(lctx), token)));
-        }
-
-        // upper bound of context byte length.
-        // strings with this byte length should always tokenize to at least context_length tokens.
-        size_t context_byte_len = max_token_text_size*context_length;
-
-        for (unsigned i=0; i<out_samples_begin.size(); ++i) {
-            // determine sample begin and end from pattern positions
-            size_t sample_begin = out_samples_begin[i] + sample_begin_offset;
-            size_t sample_end   = overlapping_samples
-                ? std::min(
-                    data_str.size(),
-                    sample_begin + context_byte_len)
-                : (i+1 < out_samples_begin.size()
-                    ? out_samples_begin[i+1]
-                    : data_str.size());
-            if (sample_end < utf8_units.size() && utf8_units[sample_end] > 0) {
-                // sample end is in the middle of an utf8 character.
-                // advance sample_end to the begin of the next utf8 character.
-                sample_end += utf8_nunits[sample_end] - utf8_units[sample_end];
-            }
-            size_t sample_size = sample_end - sample_begin;
-            if (sample_size == 0) {
-                ++found_empty_sample;
-            }
-
-            if (sample_size > 0) {
-                // llama_tokenize expects zero terminated string,
-                // copy sample into buffer and zero terminate it.
-                buf_sample.resize(sample_size);
-                memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size);
-
-                // printf("sample: '%s'\n", buf_sample.data());
-
-                // tokenize the sample
-                tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
-                int n_tokens = llama_tokenize(llama_get_model(lctx),
-                    buf_sample.data(),
-                    (int) buf_sample.size(),
-                    tok_sample.data(),
-                    (int) tok_sample.size(),
-                    false, false);
-                if (n_tokens < 0) {
-                    tok_sample.resize(-n_tokens);
-                    n_tokens = llama_tokenize(llama_get_model(lctx),
-                        buf_sample.data(),
-                        (int) buf_sample.size(),
-                        tok_sample.data(),
-                        (int) tok_sample.size(),
-                        false, false);
-                    GGML_ASSERT(n_tokens >= 0);
-                }
-                GGML_ASSERT(n_tokens <= (int) tok_sample.size());
-
-                if ((size_t) n_tokens > context_length) {
-                    ++found_too_big_sample;
-                } else if ((size_t) n_tokens < context_length) {
-                    ++found_too_small_sample;
-                }
-                found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens);
-                found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens);
-
-                // write out tokens, start and size of sample
-                // overwrite the string start position with the token start position
-                out_samples_begin[i] = out_tokens.size();
-                out_samples_size[i] = (size_t) n_tokens;
-                out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens);
-            } else {
-                out_samples_begin[i] = out_tokens.size();
-                out_samples_size[i] = 0;
-            }
-
-        }
-        if (found_too_big_sample > 0) {
-            printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n",
-                __func__, found_too_big_sample, found_max_sample_size, context_length);
-        }
-
-        if (found_too_small_sample > 0) {
-            printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n",
-                __func__, found_too_small_sample, found_min_sample_size, context_length);
-        }
-
-        if (found_empty_sample) {
-            printf("%s: warning: found %zu empty samples.\n",
-                __func__, found_empty_sample);
-        }
-    }
-    printf("%s: total number of samples: %zu\n",
-        __func__, out_samples_begin.size());
-
-    GGML_ASSERT(out_samples_begin.size() == out_samples_size.size());
-
-    return out_tokens.size();
-}
-
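The removed logic above splits the raw training text on a delimiter (sample_start) and tokenizes each piece with llama_tokenize. A minimal standalone sketch of the same delimiter-scanning idea, with the same single-sample fallback; the helper name is hypothetical and not part of this package:

    #include <cstdio>
    #include <string>
    #include <vector>

    // Collect the start offset of every occurrence of `pattern`, falling back
    // to one sample at offset 0 when the pattern never occurs (as above).
    static std::vector<size_t> find_sample_starts(const std::string & data, const std::string & pattern) {
        std::vector<size_t> begins;
        size_t pos = data.find(pattern, 0);
        while (pos != std::string::npos) {
            begins.push_back(pos);
            pos = data.find(pattern, pos + pattern.size());
        }
        if (begins.empty()) {
            begins.push_back(0);
        }
        return begins;
    }

    int main() {
        for (size_t b : find_sample_starts("<s>a<s>b<s>c", "<s>")) {
            printf("%zu\n", b); // prints 0, 4, 8
        }
    }
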
-std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) {
-    std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
-    return replace_str(filename, pattern_it, sit.c_str());
-}
-
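A usage sketch, assuming replace_str (defined elsewhere in the removed file) substitutes the pattern with the given string; the filenames match the defaults set just below:

    get_train_filename("checkpoint-ITERATION.gguf", "ITERATION", "LATEST", 42);  // -> "checkpoint-42.gguf"
    get_train_filename("checkpoint-ITERATION.gguf", "ITERATION", "LATEST", -1);  // -> "checkpoint-LATEST.gguf"
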
-struct train_params_common get_default_train_params_common() {
-    struct train_params_common params;
-    params.fn_train_data     = "shakespeare.txt";
-    params.fn_checkpoint_in  = "checkpoint.gguf";
-    params.fn_checkpoint_out = "checkpoint-ITERATION.gguf";
-    params.pattern_fn_it     = "ITERATION";
-    params.fn_latest         = "LATEST";
-
-    params.print_usage = false;
-
-    params.save_every = 10;
-
-    params.seed = -1;
-
-    params.n_ctx                   = 128;
-    params.n_threads               = 6;
-    params.n_batch                 = 8;
-    params.n_gradient_accumulation = 1;
-    params.n_epochs                = -1;
-    params.n_gpu_layers            = 0;
-
-    params.custom_n_ctx = false;
-
-    params.use_flash         = false;
-    params.use_checkpointing = true;
-
-    params.sample_start           = "";
-    params.include_sample_start   = false;
-    params.escape                 = false;
-    params.overlapping_samples    = false;
-    params.fill_with_next_samples = false;
-    params.separate_with_eos      = false;
-    params.separate_with_bos      = true;
-    params.sample_random_offsets  = false;
-    params.force_reshuffle        = false;
-
-    params.opt_past               = 0;
-    params.opt_delta              = 1e-5f;
-    params.opt_max_no_improvement = 0;
-
-    params.warmup            = 100;
-    params.cos_decay_steps   = 1000;
-    params.cos_decay_restart = 1.1f;
-    params.cos_decay_min     = 0.1f;
-    params.enable_restart    = false;
-
-    params.adam_n_iter         = 256;
-    params.adam_alpha          = 1e-3f;
-    params.adam_min_alpha      = 0;
-    params.adam_decay          = 1e-1f;
-    params.adam_decay_min_ndim = 2;
-    params.adam_beta1          = 0.9f;
-    params.adam_beta2          = 0.999f;
-    params.adam_gclip          = 1.0f;
-    params.adam_eps_f          = 0.0f;
-
-    return params;
-}
-
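A caller would typically take these defaults and override a few fields before argument parsing; a small hypothetical sketch:

    struct train_params_common params = get_default_train_params_common();
    params.n_ctx        = 256;   // train with a larger context than the default 128
    params.n_threads    = 12;
    params.sample_start = "<s>"; // split the training data on this delimiter
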
-void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) {
-    // fprintf(stderr, "usage: %s [options]\n", argv[0]);
-    // fprintf(stderr, "\n");
-    // fprintf(stderr, "options:\n");
-    // fprintf(stderr, "  -h, --help                    show this help message and exit\n");
-    fprintf(stderr, "  --train-data FNAME            path from which to load training data (default '%s')\n", params->fn_train_data);
-    fprintf(stderr, "  --checkpoint-in FNAME         path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
-    fprintf(stderr, "  --checkpoint-out FNAME        path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
-    fprintf(stderr, "  --pattern-fn-it STR           pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it);
-    fprintf(stderr, "  --fn-latest STR               string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest);
-    fprintf(stderr, "  --save-every N                save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every);
-    fprintf(stderr, "  -s SEED, --seed SEED          RNG seed (default: -1, use random seed for -1)\n");
-    fprintf(stderr, "  -c N, --ctx N                 Context size used during training (default %d)\n", params->n_ctx);
-    fprintf(stderr, "  -t N, --threads N             Number of threads (default %d)\n", params->n_threads);
-    fprintf(stderr, "  -b N, --batch N               Parallel batch size (default %d)\n", params->n_batch);
-    fprintf(stderr, "  --grad-acc N                  Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
-    fprintf(stderr, "  --sample-start STR            Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str());
-    fprintf(stderr, "  --include-sample-start        Include the sample start in the samples. (default off)\n");
-    fprintf(stderr, "  --escape                      process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
-    fprintf(stderr, "  --overlapping-samples         Samples may overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n");
-    fprintf(stderr, "  --fill-with-next-samples      Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n");
-    fprintf(stderr, "  --separate-with-eos           When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : "");
-    fprintf(stderr, "  --separate-with-bos           When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : "");
-    fprintf(stderr, "  --no-separate-with-eos        When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : "");
-    fprintf(stderr, "  --no-separate-with-bos        When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : "");
-    fprintf(stderr, "  --sample-random-offsets       Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : "");
-    fprintf(stderr, "  --force-reshuffle             Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n");
-    fprintf(stderr, "  --no-flash                    Don't use flash attention \n");
-    fprintf(stderr, "  --use-flash                   Use flash attention (default)\n");
-    fprintf(stderr, "  --no-checkpointing            Don't use gradient checkpointing\n");
-    fprintf(stderr, "  --use-checkpointing           Use gradient checkpointing (default)\n");
-    fprintf(stderr, "  --warmup N                    Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup);
-    fprintf(stderr, "  --cos-decay-steps N           Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
-    fprintf(stderr, "  --cos-decay-restart N         Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
-    fprintf(stderr, "  --cos-decay-min N             Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min);
-    fprintf(stderr, "  --enable-restart N            Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
-    fprintf(stderr, "  --disable-restart N           Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
-    fprintf(stderr, "  --opt-past N                  Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
-    fprintf(stderr, "  --opt-delta N                 Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
-    fprintf(stderr, "  --opt-max-no-improvement N    Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
-    fprintf(stderr, "  --epochs N                    Maximum number epochs to process. (default %d)\n", params->n_epochs);
-    fprintf(stderr, "  --adam-iter N                 Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
-    fprintf(stderr, "  --adam-alpha N                Adam learning rate alpha (default %f)\n", params->adam_alpha);
-    fprintf(stderr, "  --adam-min-alpha N            Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
-    fprintf(stderr, "  --adam-decay N                AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
-    fprintf(stderr, "  --adam-decay-min-ndim N       Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
-    fprintf(stderr, "  --adam-beta1 N                AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
-    fprintf(stderr, "  --adam-beta2 N                AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
-    fprintf(stderr, "  --adam-gclip N                AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
-    fprintf(stderr, "  --adam-epsf N                 AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
-    fprintf(stderr, "  -ngl N, --n-gpu-layers N      Number of model layers to offload to GPU (default %d)", params->n_gpu_layers);
-    fprintf(stderr, "\n");
-}
-
-bool consume_common_train_arg(
-    int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param
-) {
-    int& i = *idx;
-    std::string arg = argv[i];
-    const std::string arg_prefix = "--";
-    if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
-        std::replace(arg.begin(), arg.end(), '_', '-');
-    }
-    if (arg == "--train-data") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->fn_train_data = argv[i];
-    } else if (arg == "--checkpoint-in") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->fn_checkpoint_in = argv[i];
-    } else if (arg == "--checkpoint-out") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->fn_checkpoint_out = argv[i];
-    } else if (arg == "--pattern-fn-it") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->pattern_fn_it = argv[i];
-    } else if (arg == "--fn-latest") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->fn_latest = argv[i];
-    } else if (arg == "--save-every") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->save_every = std::stoi(argv[i]);
-    } else if (arg == "-s" || arg == "--seed") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->seed = std::stoi(argv[i]);
-    } else if (arg == "-c" || arg == "--ctx") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->n_ctx = std::stoi(argv[i]);
-        params->custom_n_ctx = true;
-    } else if (arg == "-t" || arg == "--threads") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->n_threads = std::stoi(argv[i]);
-    } else if (arg == "-b" || arg == "--batch") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->n_batch = std::stoi(argv[i]);
-    } else if (arg == "--grad-acc") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
-    } else if (arg == "--sample-start") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->sample_start = std::string(argv[i]);
-    } else if (arg == "--escape") {
-        params->escape = true;
-    } else if (arg == "--include-sample-start") {
-        params->include_sample_start = true;
-    } else if (arg == "--overlapping-samples") {
-        params->overlapping_samples = true;
-    } else if (arg == "--fill-with-next-samples") {
-        params->fill_with_next_samples = true;
-    } else if (arg == "--separate-with-eos") {
-        params->separate_with_eos = true;
-    } else if (arg == "--separate-with-bos") {
-        params->separate_with_bos = true;
-    } else if (arg == "--no-separate-with-eos") {
-        params->separate_with_eos = false;
-    } else if (arg == "--no-separate-with-bos") {
-        params->separate_with_bos = false;
-    } else if (arg == "--sample-random-offsets") {
-        params->sample_random_offsets = true;
-    } else if (arg == "--force-reshuffle") {
-        params->force_reshuffle = true;
-    } else if (arg == "--no-flash") {
-        params->use_flash = false;
-    } else if (arg == "--use-flash") {
-        params->use_flash = true;
-    } else if (arg == "--no-checkpointing") {
-        params->use_checkpointing = false;
-    } else if (arg == "--use-checkpointing") {
-        params->use_checkpointing = true;
-    } else if (arg == "--warmup") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->warmup = std::stoi(argv[i]);
-    } else if (arg == "--cos-decay-steps") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->cos_decay_steps = std::stoi(argv[i]);
-    } else if (arg == "--cos-decay-restart") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->cos_decay_restart = std::stof(argv[i]);
-    } else if (arg == "--cos-decay-min") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->cos_decay_min = std::stof(argv[i]);
-    } else if (arg == "--enable-restart") {
-        params->enable_restart = true;
-    } else if (arg == "--disable-restart") {
-        params->enable_restart = false;
-    } else if (arg == "--opt-past") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->opt_past = std::stoi(argv[i]);
-    } else if (arg == "--opt-delta") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->opt_delta = std::stof(argv[i]);
-    } else if (arg == "--opt-max-no-improvement") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->opt_max_no_improvement = std::stoi(argv[i]);
-    } else if (arg == "--adam-epsf") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->adam_eps_f = std::stof(argv[i]);
-    } else if (arg == "--epochs") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->n_epochs = std::stoi(argv[i]);
-    } else if (arg == "--adam-iter") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->adam_n_iter = std::stoi(argv[i]);
-    } else if (arg == "--adam-alpha") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->adam_alpha = std::stof(argv[i]);
-    } else if (arg == "--adam-min-alpha") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->adam_min_alpha = std::stof(argv[i]);
-    } else if (arg == "--adam-decay") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->adam_decay = std::stof(argv[i]);
-    } else if (arg == "--adam-decay-min-ndim") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->adam_decay_min_ndim = std::stoi(argv[i]);
-    } else if (arg == "--adam-beta1") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->adam_beta1 = std::stof(argv[i]);
-    } else if (arg == "--adam-beta2") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->adam_beta2 = std::stof(argv[i]);
-    } else if (arg == "--adam-gclip") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        params->adam_gclip = std::stof(argv[i]);
-    } else if (arg == "-ngl" || arg == "--n-gpu-layers") {
-        if (++i >= argc) {
-            *invalid_param = true;
-            return true;
-        }
-        if (llama_supports_gpu_offload()) {
-            params->n_gpu_layers = std::stoi(argv[i]);
-        } else {
-            fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
-            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
-        }
-    } else if (arg == "-h" || arg == "--help") {
-        params->print_usage = true;
-        return true;
-    } else {
-        return false;
-    }
-    return true;
-}
-
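consume_common_train_arg handles one argv entry per call, advancing *idx when a flag takes a value, and returns false for flags it does not recognize. A minimal sketch of a driver loop a program might build around it; this is hypothetical, and real callers would also handle their own program-specific flags in the same loop:

    static bool parse_args_sketch(int argc, char ** argv, struct train_params_common * params) {
        bool invalid_param = false;
        for (int i = 1; i < argc; ++i) {
            if (!consume_common_train_arg(argc, argv, &i, params, &invalid_param)) {
                fprintf(stderr, "error: unknown argument: %s\n", argv[i]);
                return false;
            }
            if (invalid_param) {
                // the flag at argv[i-1] was missing its value
                fprintf(stderr, "error: invalid parameter for argument: %s\n", argv[i - 1]);
                return false;
            }
        }
        if (params->print_usage) {
            print_common_train_usage(argc, argv, params);
            return false;
        }
        finish_processing_train_args(params); // applies --escape to sample_start (see below)
        return true;
    }
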
-void finish_processing_train_args(struct train_params_common * params) {
-    if (params->escape) {
-        string_process_escapes(params->sample_start);
-    }
-}
-
-void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
-    struct train_opt_callback_data * data   = (struct train_opt_callback_data *) vdata;
-    struct train_params_common     * params = data->params;
-    struct train_state             * train  = data->train;
-    struct ggml_opt_context        * opt    = train->opt;
-    int n_batch = params->n_batch;
-    int n_ctx = params->n_ctx;
-
-    if (accum_step == 0) {
-        // time measurement
-        int64_t now = ggml_time_ms();
-        if (now > data->last_time && opt->iter > data->first_iter) {
-            double dt = (double) (now - data->last_time);
-            if (data->millis_per_iter == 0.0) {
-                data->millis_per_iter = dt;
-            } else {
-                const double gain = 0.7;
-                data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain;
-            }
-        }
-
-        double remaining_millis = 0.0;
-        if (data->millis_per_iter > 0.0) {
-            const int n_iter         = params->adam_n_iter;
-            const int done_iter      = opt->iter - data->first_iter;
-            const int remaining_iter = n_iter - done_iter;
-            remaining_millis = remaining_iter * data->millis_per_iter;
-        }
-
-        // file saving
-        const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
-        if (save_now) {
-            int new_iters = opt->iter - data->last_save_iter;
-            train->train_its += new_iters;
-            train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
-
-            if (data->save_cb) {
-                data->save_cb(data->save_data, train);
-            }
-
-            data->last_save_iter = opt->iter;
-        }
-
-        // exclude file saving from time measurement, by measuring last_time after saving
-        data->last_time = ggml_time_ms();
-
-        *sched = learning_schedule(
-            opt->iter,
-            params->warmup,
-            params->cos_decay_steps,
-            params->adam_alpha,
-            params->adam_min_alpha,
-            params->cos_decay_min,
-            params->cos_decay_restart,
-            params->enable_restart);
-
-        int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
-        if (impr_plot > 0) impr_plot = 0;
-        if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0;
-        printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
-            __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
-            *sched, opt->loss_after);
-
-
-        if (data->millis_per_iter > 0) {
-            printf(" dt=");
-            print_duration(data->millis_per_iter);
-            printf(" eta=");
-            print_duration(remaining_millis);
-        }
-
-        float improvement = opt->loss_before - opt->loss_after;
-        const float plot_scale = 10.0f;
-        int bar_len = (int)(1 + improvement*plot_scale + 0.5);
-        printf(" |");
-        for (int i=0; i<bar_len; ++i) {
-            printf("-");
-        }
-        printf(">");
-        printf("\n");
-    }
-
-    int64_t used_samples = get_example_targets_batch(
-        data->lctx,
-        data->tokens_input,
-        data->target_probs,
-        train->shuffle_next_sample,
-        data->shuffled_samples_offs,
-        data->shuffled_samples_begin,
-        data->shuffled_samples_size,
-        data->samples_count,
-        data->tokens_data,
-        data->tokens_size,
-        params->separate_with_eos,
-        params->separate_with_bos,
-        params->fill_with_next_samples,
-        params->sample_random_offsets);
-
-    train->train_samples += used_samples;
-    train->shuffle_next_sample += used_samples;
-
-    if (train->shuffle_next_sample >= train->shuffle_sample_count) {
-        ++train->train_epochs;
-        printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
-        // note: we may have used some samples from the current shuffling more than once
-        train->shuffle_rng_state_current = train->shuffle_rng_state_next;
-        train->shuffle_rng_state_next = shuffle_samples(
-            train->shuffle_rng_state_current,
-            data->shuffled_samples_offs,
-            data->shuffled_samples_begin,
-            data->shuffled_samples_size,
-            data->samples_begin,
-            data->samples_size,
-            data->samples_count);
-        train->shuffle_next_sample = 0;
-    }
-
-    const bool last_epoch_reached = (params->n_epochs > 0 && (int64_t) train->train_epochs - data->first_epoch >= params->n_epochs);
-    if (last_epoch_reached) {
-        // allow optimization iteration at last epoch to be completed before canceling
-        if (data->iter_at_last_epoch < 0) {
-            data->iter_at_last_epoch = opt->iter;
-        } else if (opt->iter > data->iter_at_last_epoch) {
-            *cancel = true;
-        }
-    }
-}
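train_opt_callback drives the per-iteration bookkeeping: the *sched output scales the Adam learning rate via learning_schedule, which is defined earlier in the removed file and is not shown here. As a rough illustration of the warmup-plus-cosine-decay shape the parameters above suggest, a simplified sketch that ignores restarts and the alpha/min_alpha rescaling the full implementation performs:

    #include <cmath>

    // Simplified sketch: linear warmup to 1.0, then cosine decay down to decay_min.
    static float sched_sketch(int iter, int warmup, int decay_steps, float decay_min) {
        if (iter < warmup) {
            return (float) iter / (float) warmup; // linear warmup: 0 -> 1
        }
        float t = (float)(iter - warmup) / (float) decay_steps;
        if (t > 1.0f) t = 1.0f;
        const float cos_part = 0.5f * (1.0f + cosf(3.14159265f * t)); // 1 -> 0
        return decay_min + (1.0f - decay_min) * cos_part;             // 1 -> decay_min
    }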