llama-cpp-capacitor 0.0.6 → 0.0.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +9 -9
- package/cpp/LICENSE +21 -0
- package/cpp/README.md +4 -0
- package/cpp/anyascii.c +22223 -0
- package/cpp/anyascii.h +42 -0
- package/cpp/chat-parser.cpp +393 -0
- package/cpp/chat-parser.h +120 -0
- package/cpp/chat.cpp +2315 -0
- package/cpp/chat.h +221 -0
- package/cpp/common.cpp +1619 -0
- package/cpp/common.h +744 -0
- package/cpp/ggml-alloc.c +1028 -0
- package/cpp/ggml-alloc.h +76 -0
- package/cpp/ggml-backend-impl.h +255 -0
- package/cpp/ggml-backend-reg.cpp +600 -0
- package/cpp/ggml-backend.cpp +2118 -0
- package/cpp/ggml-backend.h +354 -0
- package/cpp/ggml-common.h +1878 -0
- package/cpp/ggml-cpp.h +39 -0
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2512 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +3650 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +1891 -0
- package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml-cpu/arch/x86/quants.c +3820 -0
- package/cpp/ggml-cpu/arch/x86/repack.cpp +6307 -0
- package/cpp/ggml-cpu/arch-fallback.h +215 -0
- package/cpp/ggml-cpu/binary-ops.cpp +158 -0
- package/cpp/ggml-cpu/binary-ops.h +16 -0
- package/cpp/ggml-cpu/common.h +73 -0
- package/cpp/ggml-cpu/ggml-cpu-impl.h +525 -0
- package/cpp/ggml-cpu/ggml-cpu.c +3578 -0
- package/cpp/ggml-cpu/ggml-cpu.cpp +672 -0
- package/cpp/ggml-cpu/ops.cpp +10587 -0
- package/cpp/ggml-cpu/ops.h +114 -0
- package/cpp/ggml-cpu/quants.c +1193 -0
- package/cpp/ggml-cpu/quants.h +97 -0
- package/cpp/ggml-cpu/repack.cpp +1982 -0
- package/cpp/ggml-cpu/repack.h +120 -0
- package/cpp/ggml-cpu/simd-mappings.h +1184 -0
- package/cpp/ggml-cpu/traits.cpp +36 -0
- package/cpp/ggml-cpu/traits.h +38 -0
- package/cpp/ggml-cpu/unary-ops.cpp +186 -0
- package/cpp/ggml-cpu/unary-ops.h +28 -0
- package/cpp/ggml-cpu/vec.cpp +348 -0
- package/cpp/ggml-cpu/vec.h +1121 -0
- package/cpp/ggml-cpu.h +145 -0
- package/cpp/ggml-impl.h +622 -0
- package/cpp/ggml-metal-impl.h +688 -0
- package/cpp/ggml-metal.h +66 -0
- package/cpp/ggml-metal.m +6833 -0
- package/cpp/ggml-opt.cpp +1093 -0
- package/cpp/ggml-opt.h +256 -0
- package/cpp/ggml-quants.c +5324 -0
- package/cpp/ggml-quants.h +106 -0
- package/cpp/ggml-threading.cpp +12 -0
- package/cpp/ggml-threading.h +14 -0
- package/cpp/ggml.c +7108 -0
- package/cpp/ggml.h +2492 -0
- package/cpp/gguf.cpp +1358 -0
- package/cpp/gguf.h +202 -0
- package/cpp/json-partial.cpp +256 -0
- package/cpp/json-partial.h +38 -0
- package/cpp/json-schema-to-grammar.cpp +985 -0
- package/cpp/json-schema-to-grammar.h +21 -0
- package/cpp/llama-adapter.cpp +388 -0
- package/cpp/llama-adapter.h +76 -0
- package/cpp/llama-arch.cpp +2355 -0
- package/cpp/llama-arch.h +499 -0
- package/cpp/llama-batch.cpp +875 -0
- package/cpp/llama-batch.h +160 -0
- package/cpp/llama-chat.cpp +783 -0
- package/cpp/llama-chat.h +65 -0
- package/cpp/llama-context.cpp +2748 -0
- package/cpp/llama-context.h +306 -0
- package/cpp/llama-cparams.cpp +5 -0
- package/cpp/llama-cparams.h +41 -0
- package/cpp/llama-cpp.h +30 -0
- package/cpp/llama-grammar.cpp +1229 -0
- package/cpp/llama-grammar.h +173 -0
- package/cpp/llama-graph.cpp +1891 -0
- package/cpp/llama-graph.h +810 -0
- package/cpp/llama-hparams.cpp +180 -0
- package/cpp/llama-hparams.h +233 -0
- package/cpp/llama-impl.cpp +167 -0
- package/cpp/llama-impl.h +61 -0
- package/cpp/llama-io.cpp +15 -0
- package/cpp/llama-io.h +35 -0
- package/cpp/llama-kv-cache-iswa.cpp +318 -0
- package/cpp/llama-kv-cache-iswa.h +135 -0
- package/cpp/llama-kv-cache.cpp +2059 -0
- package/cpp/llama-kv-cache.h +374 -0
- package/cpp/llama-kv-cells.h +491 -0
- package/cpp/llama-memory-hybrid.cpp +258 -0
- package/cpp/llama-memory-hybrid.h +137 -0
- package/cpp/llama-memory-recurrent.cpp +1146 -0
- package/cpp/llama-memory-recurrent.h +179 -0
- package/cpp/llama-memory.cpp +59 -0
- package/cpp/llama-memory.h +119 -0
- package/cpp/llama-mmap.cpp +600 -0
- package/cpp/llama-mmap.h +68 -0
- package/cpp/llama-model-loader.cpp +1164 -0
- package/cpp/llama-model-loader.h +170 -0
- package/cpp/llama-model-saver.cpp +282 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +19042 -0
- package/cpp/llama-model.h +491 -0
- package/cpp/llama-sampling.cpp +2575 -0
- package/cpp/llama-sampling.h +32 -0
- package/cpp/llama-vocab.cpp +3792 -0
- package/cpp/llama-vocab.h +176 -0
- package/cpp/llama.cpp +358 -0
- package/cpp/llama.h +1373 -0
- package/cpp/log.cpp +427 -0
- package/cpp/log.h +103 -0
- package/cpp/minja/chat-template.hpp +550 -0
- package/cpp/minja/minja.hpp +3009 -0
- package/cpp/nlohmann/json.hpp +25526 -0
- package/cpp/nlohmann/json_fwd.hpp +187 -0
- package/cpp/regex-partial.cpp +204 -0
- package/cpp/regex-partial.h +56 -0
- package/cpp/rn-completion.cpp +681 -0
- package/cpp/rn-completion.h +116 -0
- package/cpp/rn-llama.cpp +345 -0
- package/cpp/rn-llama.h +149 -0
- package/cpp/rn-mtmd.hpp +602 -0
- package/cpp/rn-tts.cpp +591 -0
- package/cpp/rn-tts.h +59 -0
- package/cpp/sampling.cpp +579 -0
- package/cpp/sampling.h +107 -0
- package/cpp/tools/mtmd/clip-impl.h +473 -0
- package/cpp/tools/mtmd/clip.cpp +4322 -0
- package/cpp/tools/mtmd/clip.h +106 -0
- package/cpp/tools/mtmd/miniaudio/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +769 -0
- package/cpp/tools/mtmd/mtmd-audio.h +47 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +460 -0
- package/cpp/tools/mtmd/mtmd-helper.h +91 -0
- package/cpp/tools/mtmd/mtmd.cpp +1066 -0
- package/cpp/tools/mtmd/mtmd.h +298 -0
- package/cpp/tools/mtmd/stb/stb_image.h +7988 -0
- package/cpp/unicode-data.cpp +7034 -0
- package/cpp/unicode-data.h +20 -0
- package/cpp/unicode.cpp +1061 -0
- package/cpp/unicode.h +68 -0
- package/package.json +2 -1
package/cpp/ggml-opt.cpp
ADDED
|
@@ -0,0 +1,1093 @@
|
|
|
1
|
+
#include "ggml-opt.h"
|
|
2
|
+
|
|
3
|
+
#include "ggml.h"
|
|
4
|
+
#include "ggml-alloc.h"
|
|
5
|
+
#include "ggml-backend.h"
|
|
6
|
+
#include "ggml-impl.h"
|
|
7
|
+
|
|
8
|
+
#include <algorithm>
|
|
9
|
+
#include <cmath>
|
|
10
|
+
#include <cstdint>
|
|
11
|
+
#include <cinttypes>
|
|
12
|
+
#include <map>
|
|
13
|
+
#include <random>
|
|
14
|
+
#include <vector>
|
|
15
|
+
|
|
16
|
+
struct lm_ggml_opt_dataset {
|
|
17
|
+
struct lm_ggml_context * ctx = nullptr;
|
|
18
|
+
lm_ggml_backend_buffer_t buf = nullptr;
|
|
19
|
+
struct lm_ggml_tensor * data = nullptr;
|
|
20
|
+
struct lm_ggml_tensor * labels = nullptr;
|
|
21
|
+
|
|
22
|
+
int64_t ndata = -1;
|
|
23
|
+
int64_t ndata_shard = -1;
|
|
24
|
+
size_t nbs_data = -1;
|
|
25
|
+
size_t nbs_labels = -1;
|
|
26
|
+
|
|
27
|
+
std::vector<int64_t> permutation;
|
|
28
|
+
};
|
|
29
|
+
|
|
30
|
+
struct lm_ggml_opt_context {
|
|
31
|
+
lm_ggml_backend_sched_t backend_sched = nullptr;
|
|
32
|
+
lm_ggml_cgraph * allocated_graph = nullptr;
|
|
33
|
+
lm_ggml_cgraph * allocated_graph_copy = nullptr;
|
|
34
|
+
struct lm_ggml_context * ctx_static = nullptr;
|
|
35
|
+
struct lm_ggml_context * ctx_cpu = nullptr;
|
|
36
|
+
struct lm_ggml_context * ctx_compute = nullptr;
|
|
37
|
+
struct lm_ggml_context * ctx_copy = nullptr;
|
|
38
|
+
lm_ggml_backend_buffer_t buf_static = nullptr;
|
|
39
|
+
lm_ggml_backend_buffer_t buf_cpu = nullptr;
|
|
40
|
+
std::mt19937 rng;
|
|
41
|
+
enum lm_ggml_opt_loss_type loss_type;
|
|
42
|
+
enum lm_ggml_opt_build_type build_type;
|
|
43
|
+
enum lm_ggml_opt_build_type build_type_alloc;
|
|
44
|
+
|
|
45
|
+
struct lm_ggml_tensor * inputs = nullptr;
|
|
46
|
+
struct lm_ggml_tensor * outputs = nullptr;
|
|
47
|
+
struct lm_ggml_tensor * labels = nullptr;
|
|
48
|
+
|
|
49
|
+
struct lm_ggml_tensor * loss = nullptr;
|
|
50
|
+
struct lm_ggml_tensor * pred = nullptr;
|
|
51
|
+
struct lm_ggml_tensor * ncorrect = nullptr;
|
|
52
|
+
|
|
53
|
+
struct lm_ggml_cgraph * gf = nullptr;
|
|
54
|
+
struct lm_ggml_cgraph * gb_grad = nullptr;
|
|
55
|
+
struct lm_ggml_cgraph * gb_opt = nullptr;
|
|
56
|
+
bool static_graphs = false;
|
|
57
|
+
bool eval_ready = false;
|
|
58
|
+
std::vector<struct lm_ggml_tensor *> grad_accs;
|
|
59
|
+
std::vector<struct lm_ggml_tensor *> grad_m;
|
|
60
|
+
std::vector<struct lm_ggml_tensor *> grad_v;
|
|
61
|
+
|
|
62
|
+
int64_t iter = 1;
|
|
63
|
+
int32_t opt_period = 1;
|
|
64
|
+
int32_t opt_i = 0;
|
|
65
|
+
bool loss_per_datapoint = false;
|
|
66
|
+
|
|
67
|
+
lm_ggml_opt_get_optimizer_params get_opt_pars = nullptr;
|
|
68
|
+
void * get_opt_pars_ud = nullptr;
|
|
69
|
+
struct lm_ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars.
|
|
70
|
+
|
|
71
|
+
enum lm_ggml_opt_optimizer_type optimizer = LM_GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
|
72
|
+
};
|
|
73
|
+
|
|
74
|
+
// Accumulated results of evaluating batches: losses, predictions and the
// number of correct predictions.
struct lm_ggml_opt_result {
    int64_t              ndata    = 0; // number of datapoints accumulated so far
    std::vector<float>   loss;         // recorded loss values
    std::vector<int32_t> pred;         // recorded predictions
    int64_t              ncorrect = 0; // number of correct predictions

    int64_t opt_period         = -1;    // -1 until set from a context
    bool    loss_per_datapoint = false;
};
|
|
83
|
+
|
|
84
|
+
// ====== Dataset ======
|
|
85
|
+
|
|
86
|
+
// Creates a dataset with room for `ndata` datapoints of `ne_datapoint` elements
// each and, if `ne_label > 0`, `ndata` labels of `ne_label` elements each.
// The data is handled in shards of `ndata_shard` datapoints; the shard
// permutation starts out as the identity.
// The tensors are allocated in a CPU buffer. The returned handle is created
// with `new` and is released by lm_ggml_opt_dataset_free.
lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init(
        enum lm_ggml_type type_data,
        enum lm_ggml_type type_label,
        int64_t           ne_datapoint,
        int64_t           ne_label,
        int64_t           ndata,
        int64_t           ndata_shard) {
    LM_GGML_ASSERT(ne_datapoint > 0);
    LM_GGML_ASSERT(ne_label >= 0);
    LM_GGML_ASSERT(ndata > 0);
    LM_GGML_ASSERT(ndata_shard > 0);

    lm_ggml_opt_dataset_t dataset = new lm_ggml_opt_dataset;
    dataset->ndata       = ndata;
    dataset->ndata_shard = ndata_shard;

    // Metadata-only context: at most two tensors (data, labels) are created in it.
    struct lm_ggml_init_params ctx_params = {
        /*.mem_size   =*/ 2*lm_ggml_tensor_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    dataset->ctx = lm_ggml_init(ctx_params);

    dataset->data     = lm_ggml_new_tensor_2d(dataset->ctx, type_data, ne_datapoint, ndata);
    dataset->nbs_data = lm_ggml_nbytes(dataset->data) * ndata_shard/ndata;

    if (ne_label <= 0) {
        dataset->labels     = nullptr;
        dataset->nbs_labels = 0;
    } else {
        dataset->labels     = lm_ggml_new_tensor_2d(dataset->ctx, type_label, ne_label, ndata);
        dataset->nbs_labels = lm_ggml_nbytes(dataset->labels) * ndata_shard/ndata;
    }

    dataset->buf = lm_ggml_backend_alloc_ctx_tensors_from_buft(dataset->ctx, lm_ggml_backend_cpu_buffer_type());

    // Identity permutation over the complete shards (integer division).
    const int64_t nshards = ndata/ndata_shard;
    dataset->permutation.reserve(nshards);
    for (int64_t ishard = 0; ishard < nshards; ++ishard) {
        dataset->permutation.push_back(ishard);
    }

    return dataset;
}
|
|
131
|
+
|
|
132
|
+
// Releases a dataset created by lm_ggml_opt_dataset_init: its backend buffer,
// its ggml context and the handle itself.
// Passing nullptr is a no-op, so cleanup paths do not need their own guard.
void lm_ggml_opt_dataset_free(lm_ggml_opt_dataset_t dataset) {
    if (dataset == nullptr) {
        return;
    }

    lm_ggml_backend_buffer_free(dataset->buf);
    lm_ggml_free(dataset->ctx);
    delete dataset;
}
|
|
137
|
+
|
|
138
|
+
// Returns the total number of datapoints the dataset was created with.
int64_t lm_ggml_opt_dataset_ndata(lm_ggml_opt_dataset_t dataset) {
    const int64_t ndata_total = dataset->ndata;
    return ndata_total;
}
|
|
141
|
+
|
|
142
|
+
// Returns the tensor that stores all datapoints of the dataset.
struct lm_ggml_tensor * lm_ggml_opt_dataset_data(lm_ggml_opt_dataset_t dataset) {
    struct lm_ggml_tensor * all_data = dataset->data;
    return all_data;
}
|
|
145
|
+
|
|
146
|
+
// Returns the tensor that stores all labels, or nullptr if the dataset has no labels.
struct lm_ggml_tensor * lm_ggml_opt_dataset_labels(lm_ggml_opt_dataset_t dataset) {
    struct lm_ggml_tensor * all_labels = dataset->labels;
    return all_labels;
}
|
|
149
|
+
|
|
150
|
+
// Shuffles the shard permutation of `dataset` using the RNG of `opt_ctx`.
// A negative `idata` shuffles all shards. Otherwise only the shards that cover
// the first `idata` datapoints are shuffled; `idata` must then be a multiple
// of the shard size and must not exceed the number of datapoints.
void lm_ggml_opt_dataset_shuffle(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_dataset_t dataset, int64_t idata) {
    LM_GGML_ASSERT(idata <= dataset->ndata);

    std::vector<int64_t> & perm = dataset->permutation;

    // Default: shuffle everything. A non-negative idata narrows the range.
    auto shuffle_end = perm.end();
    if (idata >= 0) {
        LM_GGML_ASSERT(idata % dataset->ndata_shard == 0);
        const int64_t ishard_max = idata / dataset->ndata_shard;
        shuffle_end = perm.begin() + ishard_max;
    }

    std::shuffle(perm.begin(), shuffle_end, opt_ctx->rng);
}
|
|
162
|
+
|
|
163
|
+
// Copies batch number `ibatch` of `dataset` into the tensors `data_batch` and
// `labels_batch`, following the current shard permutation.
//
// - `data_batch` must be contiguous, have the dataset's data type and a byte
//   size that is a whole multiple of the per-shard data size.
// - `labels_batch` must be given exactly when the dataset has labels; it must
//   be contiguous, have the label type and hold the same number of shards.
// - `ibatch` must be non-negative and the batch must lie within the permutation.
//
// All preconditions are enforced with LM_GGML_ASSERT.
void lm_ggml_opt_dataset_get_batch(lm_ggml_opt_dataset_t dataset, struct lm_ggml_tensor * data_batch, struct lm_ggml_tensor * labels_batch, int64_t ibatch) {
    LM_GGML_ASSERT(   data_batch && lm_ggml_is_contiguous(data_batch));
    LM_GGML_ASSERT(!labels_batch || lm_ggml_is_contiguous(labels_batch));
    LM_GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
    LM_GGML_ASSERT(   data_batch->type == dataset->data->type);
    LM_GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);

    // Without this a negative index would pass the upper-bound check below and
    // read the permutation out of bounds.
    LM_GGML_ASSERT(ibatch >= 0);

    const size_t nb_data_batch = lm_ggml_nbytes(data_batch);
    LM_GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;

    if (labels_batch) {
        const size_t nb_labels_batch = lm_ggml_nbytes(labels_batch);
        LM_GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels);
    }

    LM_GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));

    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
        // Source shard selected by the permutation for this slot of the batch.
        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];

        const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data;
        lm_ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data);

        if (!labels_batch) {
            continue;
        }

        const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels;
        lm_ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels);
    }
}
|
|
195
|
+
|
|
196
|
+
// Copies batch number `ibatch` of `dataset` into plain host memory, following
// the current shard permutation.
//
// - `data_batch` points to `nb_data_batch` writable bytes; `nb_data_batch` must
//   be a whole multiple of the per-shard data size.
// - `labels_batch` must be given exactly when the dataset has labels. Its size
//   is not passed in: the caller must provide room for the same number of
//   shards as `data_batch`, at the per-shard label size.
// - `ibatch` must be non-negative and the batch must lie within the permutation.
//
// The checkable preconditions are enforced with LM_GGML_ASSERT.
void lm_ggml_opt_dataset_get_batch_host(lm_ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
    LM_GGML_ASSERT(data_batch != nullptr);
    LM_GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
    LM_GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);

    // Without this a negative index would pass the upper-bound check below and
    // read the permutation out of bounds.
    LM_GGML_ASSERT(ibatch >= 0);

    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;

    LM_GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));

    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
        // Source shard selected by the permutation for this slot of the batch.
        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];

        const char * ptr_data       = (const char *) dataset->data->data + ishard      *dataset->nbs_data;
        char       * ptr_data_batch = (char       *) data_batch          + ishard_batch*dataset->nbs_data;
        memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);

        if (!labels_batch) {
            continue;
        }

        const char * ptr_labels       = (const char *) dataset->labels->data + ishard      *dataset->nbs_labels;
        char       * ptr_labels_batch = (char       *) labels_batch          + ishard_batch*dataset->nbs_labels;
        memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
    }
}
|
|
220
|
+
|
|
221
|
+
// ====== Model / Context ======
|
|
222
|
+
|
|
223
|
+
struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_default_optimizer_params(void * userdata) {
|
|
224
|
+
LM_GGML_UNUSED(userdata);
|
|
225
|
+
|
|
226
|
+
lm_ggml_opt_optimizer_params result;
|
|
227
|
+
|
|
228
|
+
result.adamw.alpha = 0.001f;
|
|
229
|
+
result.adamw.beta1 = 0.9f;
|
|
230
|
+
result.adamw.beta2 = 0.999f;
|
|
231
|
+
result.adamw.eps = 1e-8f;
|
|
232
|
+
result.adamw.wd = 0.0f;
|
|
233
|
+
|
|
234
|
+
result.sgd.alpha = 1e-3f;
|
|
235
|
+
result.sgd.wd = 0.0f;
|
|
236
|
+
|
|
237
|
+
return result;
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
// Returns a copy of the optimizer parameters that `userdata` points to.
// `userdata` must point to a valid `lm_ggml_opt_optimizer_params`; it is
// dereferenced without any check.
struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_constant_optimizer_params(void * userdata) {
    const struct lm_ggml_opt_optimizer_params * fixed = (struct lm_ggml_opt_optimizer_params *) userdata;
    return *fixed;
}
|
|
244
|
+
|
|
245
|
+
// Returns optimization parameters for the given scheduler and loss type with
// every other setting at its default: no compute context, no input/output
// tensors, a full optimizer build, an optimizer period of 1, the default
// optimizer-parameter callback without user data, and AdamW.
struct lm_ggml_opt_params lm_ggml_opt_default_params(
        lm_ggml_backend_sched_t    backend_sched,
        enum lm_ggml_opt_loss_type loss_type) {
    struct lm_ggml_opt_params defaults = {};

    defaults.backend_sched   = backend_sched;
    defaults.ctx_compute     = nullptr;
    defaults.inputs          = nullptr;
    defaults.outputs         = nullptr;
    defaults.loss_type       = loss_type;
    defaults.build_type      = LM_GGML_OPT_BUILD_TYPE_OPT;
    defaults.opt_period      = 1;
    defaults.get_opt_pars    = lm_ggml_opt_get_default_optimizer_params;
    defaults.get_opt_pars_ud = nullptr;
    defaults.optimizer       = LM_GGML_OPT_OPTIMIZER_TYPE_ADAMW;

    return defaults;
}
|
|
261
|
+
|
|
262
|
+
// Returns the duplicate of `tensor` in `ctx`, creating it on first use.
// `tensor_map` records original -> duplicate so that each tensor is duplicated
// only once. The duplicate takes over op, strides, flags, op params, name and
// the data/buffer/extra pointers of the original (no tensor data is copied);
// its view source and its sources are mapped recursively.
// A nullptr input yields nullptr.
static lm_ggml_tensor * map_tensor(std::map<lm_ggml_tensor *, lm_ggml_tensor *> & tensor_map, lm_ggml_context * ctx, lm_ggml_tensor * tensor) {
    if (tensor == nullptr) {
        return nullptr;
    }

    const auto known = tensor_map.find(tensor);
    if (known != tensor_map.end()) {
        return known->second;
    }

    lm_ggml_tensor * copy = lm_ggml_dup_tensor(ctx, tensor);
    // Register before recursing so that tensors reachable again further down
    // resolve to this duplicate.
    tensor_map[tensor] = copy;

    copy->op = tensor->op;
    for (int idim = 0; idim < LM_GGML_MAX_DIMS; ++idim) {
        copy->nb[idim] = tensor->nb[idim];
    }
    copy->flags = tensor->flags;
    memcpy(copy->op_params, tensor->op_params, sizeof(tensor->op_params));
    strcpy(copy->name, tensor->name);

    // Shallow copy: the duplicate points at the same data as the original.
    copy->data   = tensor->data;
    copy->buffer = tensor->buffer;
    copy->extra  = tensor->extra;

    copy->view_offs = tensor->view_offs;
    copy->view_src  = map_tensor(tensor_map, ctx, tensor->view_src);
    for (int isrc = 0; isrc < LM_GGML_MAX_SRC; ++isrc) {
        copy->src[isrc] = map_tensor(tensor_map, ctx, tensor->src[isrc]);
    }

    return copy;
}
|
|
292
|
+
|
|
293
|
+
// Builds a copy of graph `src` in `ctx`. Every leaf and node is duplicated via
// map_tensor and re-expanded into a new graph of the same size with gradient
// storage. The copy must end up with the same leaf and node counts as `src`.
// The gradient and gradient-accumulator pointers of `src` are then carried
// over per node; the gradient tensors themselves are shared, not duplicated.
static lm_ggml_cgraph * dup_graph(lm_ggml_context * ctx, lm_ggml_cgraph * src) {
    std::map<lm_ggml_tensor *, lm_ggml_tensor *> tensor_map;

    lm_ggml_cgraph * dst = lm_ggml_new_graph_custom(ctx, src->size, /*grads =*/ true);

    for (int ileaf = 0; ileaf < src->n_leafs; ++ileaf) {
        lm_ggml_tensor * leaf_copy = map_tensor(tensor_map, ctx, src->leafs[ileaf]);
        lm_ggml_build_forward_expand(dst, leaf_copy);
    }
    LM_GGML_ASSERT(dst->n_leafs == src->n_leafs);

    for (int inode = 0; inode < src->n_nodes; ++inode) {
        lm_ggml_tensor * node_copy = map_tensor(tensor_map, ctx, src->nodes[inode]);
        lm_ggml_build_forward_expand(dst, node_copy);
    }
    LM_GGML_ASSERT(dst->n_nodes == src->n_nodes);

    // Gradients are looked up through each graph's hash set, so the slot has
    // to be resolved separately for the source and the destination graph.
    for (int inode = 0; inode < src->n_nodes; ++inode) {
        const size_t slot_src = lm_ggml_hash_find(&src->visited_hash_set, src->nodes[inode]);
        const size_t slot_dst = lm_ggml_hash_find(&dst->visited_hash_set, dst->nodes[inode]);

        LM_GGML_ASSERT(slot_src != LM_GGML_HASHSET_FULL);
        LM_GGML_ASSERT(lm_ggml_bitset_get(src->visited_hash_set.used, slot_src));
        LM_GGML_ASSERT(slot_dst != LM_GGML_HASHSET_FULL);
        LM_GGML_ASSERT(lm_ggml_bitset_get(dst->visited_hash_set.used, slot_dst));

        dst->grads[slot_dst]     = src->grads[slot_src];
        dst->grad_accs[slot_dst] = src->grad_accs[slot_src];
    }

    return dst;
}
|
|
321
|
+
|
|
322
|
+
// Builds the graphs requested by `opt_ctx->build_type` on top of the forward
// graph `opt_ctx->gf`:
//   1. appends the loss (and for cross entropy also `pred` and `ncorrect`) to gf,
//   2. creates gb_grad (forward + backward pass),
//   3. creates gb_opt (forward + backward pass + one optimizer step per parameter).
// The function returns as soon as the requested build type is complete.
//
// Static tensors (gradient accumulators, optimizer momenta and, with static
// graphs, the results) live in ctx_static. They are created and allocated
// only while buf_static is still unset; later calls reuse them.
// The optimizer parameter tensor lives in ctx_cpu, which is recreated on every call.
//
// Requires ctx_compute to be set and, with static graphs, statically allocated inputs.
static void lm_ggml_opt_build(lm_ggml_opt_context_t opt_ctx) {
    LM_GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with lm_ggml_opt_prepare_alloc");
    LM_GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");

    const enum lm_ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;

    // Separate gradient accumulators are skipped only for static graphs that
    // run an optimizer step on every evaluation (opt_period == 1).
    const bool accumulate = opt_ctx->build_type_alloc >= LM_GGML_OPT_BUILD_TYPE_GRAD &&
        !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);

    // Only AdamW keeps per-parameter momenta.
    const bool need_momenta = opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_OPT &&
        opt_ctx->optimizer == LM_GGML_OPT_OPTIMIZER_TYPE_ADAMW;

    lm_ggml_set_input(opt_ctx->inputs);
    lm_ggml_set_output(opt_ctx->outputs);

    // Count the trainable parameters; a user-provided loss node is not supported.
    int n_param = 0;
    for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
        const struct lm_ggml_tensor * node = opt_ctx->gf->nodes[i];
        if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
            n_param++;
        }
        LM_GGML_ASSERT(!(node->flags & LM_GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
    }

    if (!opt_ctx->ctx_static) {
        // The static context is used for:
        // - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
        // - optimizer momenta (2 tensors per param)
        // - labels (if using static graphs)
        // - loss (if using static graphs, up to 5 tensors)
        // - pred (if using static graphs)
        // - ncorrect (if using static graphs, 2 tensors).
        constexpr size_t n_loss = 1;
        const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
        const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
        const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * lm_ggml_tensor_overhead();
        struct lm_ggml_init_params params = {
            /*.mem_size   =*/ size_meta,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        opt_ctx->ctx_static = lm_ggml_init(params);
    }
    LM_GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);

    {
        // The cpu context is allocated statically if using static graphs, dynamically otherwise.
        // It is used for:
        // - optimizer parameters (1 shared for all optimizer invocations)
        const size_t size_meta = 1 * lm_ggml_tensor_overhead();
        struct lm_ggml_init_params params = {
            /*.mem_size   =*/ size_meta,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        lm_ggml_free(opt_ctx->ctx_cpu);
        opt_ctx->ctx_cpu = lm_ggml_init(params);

        lm_ggml_backend_buffer_free(opt_ctx->buf_cpu);
        opt_ctx->buf_cpu = nullptr;
    }

    // Result tensors go into the static context for static graphs, otherwise into the compute context.
    struct lm_ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;

    switch (opt_ctx->loss_type) {
        case LM_GGML_OPT_LOSS_TYPE_MEAN: {
            opt_ctx->loss = lm_ggml_sum(ctx_results, opt_ctx->outputs);
            lm_ggml_set_name(opt_ctx->loss, "loss_sum");
            const float scale = 1.0f / (opt_ctx->opt_period * lm_ggml_nelements(opt_ctx->outputs));
            opt_ctx->loss = lm_ggml_scale(ctx_results, opt_ctx->loss, scale);
            lm_ggml_set_name(opt_ctx->loss, "loss_mean");
            opt_ctx->loss_per_datapoint = true;
            break;
        }
        case LM_GGML_OPT_LOSS_TYPE_SUM: {
            opt_ctx->loss = lm_ggml_sum(ctx_results, opt_ctx->outputs);
            lm_ggml_set_name(opt_ctx->loss, "loss_sum");
            opt_ctx->loss_per_datapoint = false;
            break;
        }
        case LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
            opt_ctx->labels = lm_ggml_dup_tensor(ctx_results, opt_ctx->outputs);
            lm_ggml_set_input(opt_ctx->labels);
            lm_ggml_set_name(opt_ctx->labels, "labels");
            opt_ctx->loss = lm_ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
            lm_ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
            if (opt_ctx->opt_period > 1) {
                opt_ctx->loss = lm_ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
                lm_ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
            }
            opt_ctx->loss_per_datapoint = true;
            break;
        }
        case LM_GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
            opt_ctx->labels = lm_ggml_dup_tensor(ctx_results, opt_ctx->outputs);
            lm_ggml_set_input(opt_ctx->labels);
            lm_ggml_set_name(opt_ctx->labels, "labels");
            opt_ctx->loss = lm_ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
            lm_ggml_set_name(opt_ctx->loss, "loss_error");
            opt_ctx->loss = lm_ggml_sqr(ctx_results, opt_ctx->loss);
            lm_ggml_set_name(opt_ctx->loss, "loss_squared_error");
            opt_ctx->loss = lm_ggml_sum(ctx_results, opt_ctx->loss);
            lm_ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
            const float scale = 1.0f / (opt_ctx->opt_period * lm_ggml_nelements(opt_ctx->outputs));
            opt_ctx->loss = lm_ggml_scale(ctx_results, opt_ctx->loss, scale);
            lm_ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
            opt_ctx->loss_per_datapoint = true;
            break;
        }
        default:
            // An unrecognized loss type would otherwise leave opt_ctx->loss
            // unset (or stale from a previous build) for the code below.
            LM_GGML_ABORT("fatal error");
    }
    lm_ggml_set_output(opt_ctx->loss);
    lm_ggml_set_loss(opt_ctx->loss);
    lm_ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);

    if (opt_ctx->loss_type == LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
        opt_ctx->pred = lm_ggml_argmax(ctx_results, opt_ctx->outputs);
        lm_ggml_set_name(opt_ctx->pred, "pred");
        lm_ggml_set_output(opt_ctx->pred);
        lm_ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);

        opt_ctx->ncorrect = lm_ggml_count_equal(ctx_results, opt_ctx->pred, lm_ggml_argmax(ctx_results, opt_ctx->labels));
        lm_ggml_set_name(opt_ctx->ncorrect, "ncorrect");
        lm_ggml_set_output(opt_ctx->ncorrect);
        lm_ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
    }

    // Forward-only builds stop here; on the first build the static tensors are allocated first.
    if (opt_ctx->buf_static) {
        if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_FORWARD) {
            return;
        }
    } else if (opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_FORWARD) {
        opt_ctx->buf_static = lm_ggml_backend_alloc_ctx_tensors(
            opt_ctx->ctx_static, lm_ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
        return;
    }

    if (opt_ctx->grad_accs.empty()) {
        LM_GGML_ASSERT(opt_ctx->build_type_alloc >= LM_GGML_OPT_BUILD_TYPE_GRAD);

        // One slot per gf node: an F32 accumulator for the loss and, when
        // accumulating, for each parameter; nullptr everywhere else.
        const int n_nodes = opt_ctx->gf->n_nodes;
        opt_ctx->grad_accs.resize(n_nodes);
        for (int i = 0; i < n_nodes; ++i) {
            lm_ggml_tensor * node = opt_ctx->gf->nodes[i];
            if ((accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) || (node->flags & LM_GGML_TENSOR_FLAG_LOSS)) {
                opt_ctx->grad_accs[i] = lm_ggml_new_tensor(opt_ctx->ctx_static, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
            } else {
                opt_ctx->grad_accs[i] = nullptr;
            }
        }

        // need_momenta already implies build_type_alloc == LM_GGML_OPT_BUILD_TYPE_OPT.
        if (need_momenta) {
            opt_ctx->grad_m.resize(n_nodes);
            opt_ctx->grad_v.resize(n_nodes);
            for (int i = 0; i < n_nodes; ++i) {
                lm_ggml_tensor * node = opt_ctx->gf->nodes[i];
                if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
                    opt_ctx->grad_m[i] = lm_ggml_new_tensor(opt_ctx->ctx_static, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
                    opt_ctx->grad_v[i] = lm_ggml_new_tensor(opt_ctx->ctx_static, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
                } else {
                    opt_ctx->grad_m[i] = nullptr;
                    opt_ctx->grad_v[i] = nullptr;
                }
            }
        }
    }

    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
    opt_ctx->gb_grad = lm_ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
    lm_ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());

    if (opt_ctx->buf_static) {
        if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_GRAD) {
            return;
        }
    } else if (opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_GRAD) {
        opt_ctx->buf_static = lm_ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, lm_ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
        lm_ggml_graph_reset(opt_ctx->gb_grad);
    }

    LM_GGML_ASSERT(opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_OPT);

    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
    opt_ctx->gb_opt = lm_ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);

    // One parameter tensor shared by all optimizer steps: 7 values when momenta are used (AdamW), 2 otherwise.
    opt_ctx->opt_step_params = lm_ggml_new_tensor_1d(opt_ctx->ctx_cpu, LM_GGML_TYPE_F32, need_momenta ? 7 : 2);
    lm_ggml_tensor * opt_params = opt_ctx->opt_step_params;
    lm_ggml_set_input(opt_params);
    const char * optimizer_name = lm_ggml_opt_optimizer_name(opt_ctx->optimizer);
    lm_ggml_format_name(opt_params, "%s_params", optimizer_name);

    // Walk the forward nodes from last to first and append an optimizer step for every parameter with a gradient.
    for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
        struct lm_ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
        struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(opt_ctx->gb_opt, node);

        if (grad && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) {
            struct lm_ggml_tensor * m = nullptr;
            struct lm_ggml_tensor * v = nullptr;
            if (need_momenta) {
                m = opt_ctx->grad_m[i];
                v = opt_ctx->grad_v[i];
                lm_ggml_format_name(m, "AdamW m for %s", node->name);
                lm_ggml_format_name(v, "AdamW v for %s", node->name);
            }
            struct lm_ggml_tensor * opt_step;
            switch (optimizer) {
                case LM_GGML_OPT_OPTIMIZER_TYPE_ADAMW:
                    opt_step = lm_ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_params);
                    break;
                case LM_GGML_OPT_OPTIMIZER_TYPE_SGD:
                    opt_step = lm_ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, opt_params);
                    break;
                default:
                    LM_GGML_ABORT("fatal error");
            }
            lm_ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
            lm_ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
        }
    }

    if (!opt_ctx->buf_static) {
        opt_ctx->buf_static = lm_ggml_backend_alloc_ctx_tensors(
            opt_ctx->ctx_static, lm_ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
        lm_ggml_graph_reset(opt_ctx->gb_opt);
    }

    opt_ctx->buf_cpu = lm_ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, lm_ggml_backend_cpu_buffer_type());
}
|
|
548
|
+
|
|
549
|
+
// Allocates a new optimization context and fills it from params.
//
// When params.ctx_compute is null the context works with non-static graphs: inputs and outputs
// must then be null as well, and no graph is created here.
// When params.ctx_compute is set, inputs and outputs are mandatory; the forward graph is created
// from outputs and lm_ggml_opt_build is run on the new context before it is returned.
lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) {
    lm_ggml_opt_context_t result = new struct lm_ggml_opt_context;

    // Backend / graph inputs supplied by the caller.
    result->backend_sched = params.backend_sched;
    result->ctx_compute   = params.ctx_compute;
    result->inputs        = params.inputs;
    result->outputs       = params.outputs;

    // What to build; build_type_alloc starts out equal to the requested build type.
    result->loss_type        = params.loss_type;
    result->build_type       = params.build_type;
    result->build_type_alloc = params.build_type;

    // Optimizer configuration.
    result->opt_period      = params.opt_period;
    result->get_opt_pars    = params.get_opt_pars;
    result->get_opt_pars_ud = params.get_opt_pars_ud;
    result->optimizer       = params.optimizer;

    LM_GGML_ASSERT(result->opt_period >= 1);

    // Graphs are static exactly when a compute context was handed in up front.
    result->static_graphs = result->ctx_compute;

    if (result->static_graphs) {
        LM_GGML_ASSERT(result->inputs);
        LM_GGML_ASSERT(result->outputs);

        // Forward pass.
        result->gf = lm_ggml_new_graph_custom(result->ctx_compute, LM_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
        lm_ggml_build_forward_expand(result->gf, result->outputs);

        lm_ggml_opt_build(result);
    } else {
        LM_GGML_ASSERT(!result->inputs);
        LM_GGML_ASSERT(!result->outputs);
    }

    return result;
}
|
|
583
|
+
|
|
584
|
+
// Releases an optimization context together with the buffers and ggml contexts it owns.
// Passing nullptr is a no-op.
void lm_ggml_opt_free(lm_ggml_opt_context_t opt_ctx) {
    if (opt_ctx == nullptr) {
        return;
    }
    lm_ggml_backend_buffer_free(opt_ctx->buf_static);
    lm_ggml_backend_buffer_free(opt_ctx->buf_cpu);
    lm_ggml_free(opt_ctx->ctx_static);
    lm_ggml_free(opt_ctx->ctx_cpu);
    // ctx_copy is created with lm_ggml_init by lm_ggml_opt_alloc when static graphs are used;
    // it has to be released here as well or it leaks.
    // NOTE(review): relies on lm_ggml_free accepting a null context, which lm_ggml_opt_alloc
    // already depends on when it frees ctx_copy before the first allocation.
    lm_ggml_free(opt_ctx->ctx_copy);
    delete opt_ctx;
}
|
|
594
|
+
|
|
595
|
+
// Resets graph state of the context.
// optimizer == false: only the gradient graph (gb_grad) is reset.
// optimizer == true:  the optimizer graph (gb_opt) is reset and the iteration counter restarts at 1.
void lm_ggml_opt_reset(lm_ggml_opt_context_t opt_ctx, bool optimizer) {
    if (!optimizer) {
        lm_ggml_graph_reset(opt_ctx->gb_grad);
        return;
    }

    lm_ggml_graph_reset(opt_ctx->gb_opt);
    opt_ctx->iter = 1;
}
|
|
603
|
+
|
|
604
|
+
// Accessor: returns the static_graphs flag stored in the context.
bool lm_ggml_opt_static_graphs(lm_ggml_opt_context_t opt_ctx) {
    const bool uses_static_graphs = opt_ctx->static_graphs;
    return uses_static_graphs;
}
|
|
607
|
+
|
|
608
|
+
// Accessor: returns the inputs tensor stored in the context.
struct lm_ggml_tensor * lm_ggml_opt_inputs(lm_ggml_opt_context_t opt_ctx) {
    struct lm_ggml_tensor * tensor = opt_ctx->inputs;
    return tensor;
}
|
|
611
|
+
|
|
612
|
+
// Accessor: returns the outputs tensor stored in the context.
struct lm_ggml_tensor * lm_ggml_opt_outputs(lm_ggml_opt_context_t opt_ctx) {
    struct lm_ggml_tensor * tensor = opt_ctx->outputs;
    return tensor;
}
|
|
615
|
+
|
|
616
|
+
// Accessor: returns the labels tensor stored in the context.
struct lm_ggml_tensor * lm_ggml_opt_labels(lm_ggml_opt_context_t opt_ctx) {
    struct lm_ggml_tensor * tensor = opt_ctx->labels;
    return tensor;
}
|
|
619
|
+
|
|
620
|
+
// Accessor: returns the loss tensor stored in the context.
struct lm_ggml_tensor * lm_ggml_opt_loss(lm_ggml_opt_context_t opt_ctx) {
    struct lm_ggml_tensor * tensor = opt_ctx->loss;
    return tensor;
}
|
|
623
|
+
|
|
624
|
+
// Accessor: returns the pred tensor stored in the context.
struct lm_ggml_tensor * lm_ggml_opt_pred(lm_ggml_opt_context_t opt_ctx) {
    struct lm_ggml_tensor * tensor = opt_ctx->pred;
    return tensor;
}
|
|
627
|
+
|
|
628
|
+
// Accessor: returns the ncorrect tensor stored in the context.
struct lm_ggml_tensor * lm_ggml_opt_ncorrect(lm_ggml_opt_context_t opt_ctx) {
    struct lm_ggml_tensor * tensor = opt_ctx->ncorrect;
    return tensor;
}
|
|
631
|
+
|
|
632
|
+
// Looks up the gradient accumulator of node in the optimizer graph (gb_opt) of the context.
struct lm_ggml_tensor * lm_ggml_opt_grad_acc(lm_ggml_opt_context_t opt_ctx, struct lm_ggml_tensor * node) {
    struct lm_ggml_cgraph * graph_opt = opt_ctx->gb_opt;
    return lm_ggml_graph_get_grad_acc(graph_opt, node);
}
|
|
635
|
+
|
|
636
|
+
// ====== Optimization Result ======
|
|
637
|
+
|
|
638
|
+
// Allocates a fresh result object with new; release it with lm_ggml_opt_result_free.
lm_ggml_opt_result_t lm_ggml_opt_result_init() {
    lm_ggml_opt_result_t fresh = new lm_ggml_opt_result;
    return fresh;
}
|
|
641
|
+
|
|
642
|
+
// Destroys a result object with delete (deleting a null pointer is a no-op).
void lm_ggml_opt_result_free(lm_ggml_opt_result_t result) {
    delete result;
}
|
|
645
|
+
|
|
646
|
+
// Puts a result object back into its empty state: counters are zeroed and the
// per-batch loss and prediction lists are emptied.
void lm_ggml_opt_result_reset(lm_ggml_opt_result_t result) {
    // scalar counters
    result->ndata    = 0;
    result->ncorrect = 0;

    // accumulated per-batch data
    result->loss.clear();
    result->pred.clear();
}
|
|
652
|
+
|
|
653
|
+
// Writes the number of datapoints recorded in result to *ndata (ndata must not be null).
void lm_ggml_opt_result_ndata(lm_ggml_opt_result_t result, int64_t * ndata) {
    const int64_t recorded = result->ndata;
    *ndata = recorded;
}
|
|
656
|
+
|
|
657
|
+
// Summarizes the per-batch losses stored in result.
//
//   loss: receives the mean of the (rescaled) batch losses if the loss is per datapoint,
//         otherwise their sum. Receives 0.0 if no batch was recorded. Must not be null.
//   unc:  optional (may be null). Receives the uncertainty of *loss, or NAN if fewer than
//         two batches were recorded.
void lm_ggml_opt_result_loss(lm_ggml_opt_result_t result, double * loss, double * unc) {
    const int64_t nbatches = result->loss.size(); // Number of physical batches.

    if (nbatches == 0) {
        *loss = 0.0;
        // unc is optional everywhere else in this function, so it must be checked here too.
        if (unc) {
            *unc = NAN;
        }
        return;
    }

    double sum         = 0.0;
    double sum_squared = 0.0;

    // The loop variable is deliberately not called "loss" so it does not shadow the output parameter.
    for (const float & loss_batch : result->loss) {
        // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch.
        const float loss_scaled = result->loss_per_datapoint ? loss_batch*result->opt_period : loss_batch;
        sum         += loss_scaled;
        sum_squared += loss_scaled*loss_scaled;
    }

    const double mean = sum/nbatches;
    *loss = result->loss_per_datapoint ? mean : sum;

    if (!unc) {
        return;
    }

    if (nbatches < 2) {
        *unc = NAN;
        return;
    }

    // variance without Bessel's correction, i.e. nbatches/(nbatches-1)
    double var_sum = sum_squared/nbatches - mean*mean;
    // Rounding can push the difference marginally below zero (e.g. all batch losses equal),
    // which would turn the sqrt below into NaN. A NaN variance is left untouched.
    if (var_sum < 0.0) {
        var_sum = 0.0;
    }
    *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1));
}
|
|
691
|
+
|
|
692
|
+
// Copies every stored prediction into the caller-provided array pred, in order.
// pred must have room for at least result->pred.size() entries.
void lm_ggml_opt_result_pred(lm_ggml_opt_result_t result, int32_t * pred) {
    size_t out_idx = 0;
    for (const auto & value : result->pred) {
        pred[out_idx] = value;
        ++out_idx;
    }
}
|
|
697
|
+
|
|
698
|
+
// Computes the accuracy (ncorrect / ndata) stored in result.
//
//   accuracy: receives the accuracy, or NAN if ncorrect is negative. Must not be null.
//   unc:      optional (may be null). Receives sqrt(acc*(1-acc)/(ndata-1)), or NAN if
//             ncorrect is negative or fewer than two datapoints were recorded.
void lm_ggml_opt_result_accuracy(lm_ggml_opt_result_t result, double * accuracy, double * unc) {
    const bool have_ncorrect = result->ncorrect >= 0;

    if (have_ncorrect) {
        *accuracy = double(result->ncorrect) / double(result->ndata);
    } else {
        *accuracy = NAN;
    }

    if (!unc) {
        return;
    }

    if (have_ncorrect && result->ndata >= 2) {
        *unc = sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1));
    } else {
        *unc = NAN;
    }
}
|
|
708
|
+
|
|
709
|
+
// ====== Computation ======
|
|
710
|
+
|
|
711
|
+
// Hands the compute context, forward graph and input/output tensors for the next evaluation
// to a context that does not use static graphs (asserted).
void lm_ggml_opt_prepare_alloc(
        lm_ggml_opt_context_t    opt_ctx,
        struct lm_ggml_context * ctx_compute,
        struct lm_ggml_cgraph  * gf,
        struct lm_ggml_tensor  * inputs,
        struct lm_ggml_tensor  * outputs) {
    LM_GGML_ASSERT(!opt_ctx->static_graphs);

    // tensors
    opt_ctx->inputs  = inputs;
    opt_ctx->outputs = outputs;

    // graph and the context it lives in
    opt_ctx->gf          = gf;
    opt_ctx->ctx_compute = ctx_compute;
}
|
|
723
|
+
|
|
724
|
+
// Selects and allocates the graph for the next call to lm_ggml_opt_eval.
//
//   backward == false: the forward graph is used.
//   backward == true:  the gradient graph is used, except on the last physical batch of an
//                      optimizer period, where the optimizer graph is used.
//
// Must not be called while an evaluation is already pending (eval_ready is asserted to be false);
// on return eval_ready is true.
void lm_ggml_opt_alloc(lm_ggml_opt_context_t opt_ctx, bool backward) {
    LM_GGML_ASSERT(!opt_ctx->eval_ready);

    // At the start of a new accumulation period (after an optimizer build) the gradient graph is reset.
    if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
        lm_ggml_graph_reset(opt_ctx->gb_grad);
    }

    if (!backward) {
        opt_ctx->build_type = LM_GGML_OPT_BUILD_TYPE_FORWARD;
    } else {
        const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
        if (opt_i_next == 0) {
            opt_ctx->build_type = LM_GGML_OPT_BUILD_TYPE_OPT;
        } else {
            opt_ctx->build_type = LM_GGML_OPT_BUILD_TYPE_GRAD;
        }
    }

    // Without static graphs the graphs have to be (re)built for every evaluation.
    if (!opt_ctx->static_graphs) {
        lm_ggml_opt_build(opt_ctx);
    }

    // Pick the graph matching the build type; anything else leaves graph null and trips the assert.
    struct lm_ggml_cgraph * graph = nullptr;
    if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_FORWARD) {
        graph = opt_ctx->gf;
    } else if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_GRAD) {
        graph = opt_ctx->gb_grad;
    } else if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_OPT) {
        graph = opt_ctx->gb_opt;
    }
    LM_GGML_ASSERT(graph);

    // Only touch the scheduler if a different graph than last time is needed.
    if (opt_ctx->allocated_graph != graph) {
        lm_ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph

        if (opt_ctx->static_graphs) {
            // Static graphs are not given to the scheduler directly; a duplicate living in ctx_copy is.
            lm_ggml_init_params params = {
                /*.mem_size   =*/ graph->size*lm_ggml_tensor_overhead() + lm_ggml_graph_overhead_custom(graph->size, graph->grads),
                /*.mem_buffer =*/ nullptr,
                /*.no_alloc   =*/ true,
            };
            lm_ggml_free(opt_ctx->ctx_copy);
            opt_ctx->ctx_copy = lm_ggml_init(params);

            opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
        } else {
            opt_ctx->allocated_graph_copy = graph;
        }

        lm_ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
        opt_ctx->allocated_graph = graph;
    }

    opt_ctx->eval_ready = true;
}
|
|
780
|
+
|
|
781
|
+
// Runs the graph prepared by lm_ggml_opt_alloc (eval_ready is asserted) and advances the
// context's counters. If the optimizer graph is the one being run, the optimizer parameters
// are fetched through get_opt_pars, validated and written to opt_step_params beforehand.
//
// If result is non-null, the loss of this batch, the predictions (if the context has a pred
// tensor) and the number of correct predictions (if available) are added to it.
void lm_ggml_opt_eval(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result_t result) {
    LM_GGML_ASSERT(opt_ctx->eval_ready);

    // Captured up front: the graph pointers are cleared further down when graphs are not static.
    const bool is_opt_step = opt_ctx->allocated_graph == opt_ctx->gb_opt;

    if (is_opt_step) {
        const lm_ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);

        switch (opt_ctx->optimizer) {
            case LM_GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
                LM_GGML_ASSERT(opt_pars.adamw.alpha >  0.0f);
                LM_GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
                LM_GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
                LM_GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
                LM_GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
                LM_GGML_ASSERT(opt_pars.adamw.eps   >= 0.0f);
                LM_GGML_ASSERT(opt_pars.adamw.wd    >= 0.0f);
                LM_GGML_ASSERT(opt_pars.adamw.wd    <= 1.0f);

                // beta1, beta2 after applying warmup
                const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
                const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));

                // Layout of the 7 floats: alpha, beta1, beta2, eps, wd, beta1h, beta2h.
                float * dst = lm_ggml_get_data_f32(opt_ctx->opt_step_params);
                dst[0] = opt_pars.adamw.alpha;
                dst[1] = opt_pars.adamw.beta1;
                dst[2] = opt_pars.adamw.beta2;
                dst[3] = opt_pars.adamw.eps;
                dst[4] = opt_pars.adamw.wd;
                dst[5] = beta1h;
                dst[6] = beta2h;
            } break;
            case LM_GGML_OPT_OPTIMIZER_TYPE_SGD: {
                LM_GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
                LM_GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
                LM_GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);

                // Layout of the 2 floats: alpha, wd.
                float * dst = lm_ggml_get_data_f32(opt_ctx->opt_step_params);
                dst[0] = opt_pars.sgd.alpha;
                dst[1] = opt_pars.sgd.wd;
            } break;
            default:
                LM_GGML_ABORT("fatal error");
        }
    }

    lm_ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);

    // iter counts optimizer steps, opt_i counts physical batches within an optimizer period.
    if (is_opt_step) {
        opt_ctx->iter += 1;
    }
    opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;

    // Non-static graphs are single-use: forget them so that they have to be provided again.
    if (!opt_ctx->static_graphs) {
        opt_ctx->gf                   = nullptr;
        opt_ctx->gb_grad              = nullptr;
        opt_ctx->gb_opt               = nullptr;
        opt_ctx->allocated_graph      = nullptr;
        opt_ctx->allocated_graph_copy = nullptr;
    }

    opt_ctx->eval_ready = false;

    if (!result) {
        return;
    }

    // The first batch fixes the loss convention and period of the result; later ones must agree.
    if (result->ndata != 0) {
        LM_GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint);
        LM_GGML_ASSERT(result->opt_period         == opt_ctx->opt_period);
    } else {
        result->loss_per_datapoint = opt_ctx->loss_per_datapoint;
        result->opt_period         = opt_ctx->opt_period;
    }

    const int64_t ndata = opt_ctx->outputs->ne[1];
    LM_GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported");
    result->ndata += ndata;

    // loss of this batch
    LM_GGML_ASSERT(lm_ggml_is_scalar(opt_ctx->loss));
    LM_GGML_ASSERT(opt_ctx->loss->type == LM_GGML_TYPE_F32);
    float loss;
    lm_ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, lm_ggml_nbytes(opt_ctx->loss));
    result->loss.push_back(loss);

    // predictions of this batch, appended to the ones already stored
    if (opt_ctx->pred) {
        LM_GGML_ASSERT(opt_ctx->pred->type == LM_GGML_TYPE_I32);
        std::vector<int32_t> batch_pred(ndata);
        lm_ggml_backend_tensor_get(opt_ctx->pred, batch_pred.data(), 0, lm_ggml_nbytes(opt_ctx->pred));
        result->pred.insert(result->pred.end(), batch_pred.begin(), batch_pred.end());
    }

    // number of correct predictions; -1 marks it as unavailable and stays that way once set
    if (opt_ctx->ncorrect && result->ncorrect >= 0) {
        LM_GGML_ASSERT(lm_ggml_is_scalar(opt_ctx->ncorrect));
        LM_GGML_ASSERT(opt_ctx->ncorrect->type == LM_GGML_TYPE_I64);
        int64_t ncorrect;
        lm_ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, lm_ggml_nbytes(opt_ctx->ncorrect));
        result->ncorrect += ncorrect;
    } else {
        result->ncorrect = -1;
    }
}
|
|
877
|
+
|
|
878
|
+
// ====== High-Level Functions ======
|
|
879
|
+
|
|
880
|
+
// Runs one pass over dataset in physical batches of inputs->ne[1] datapoints.
// Batches before idata_split are evaluated with backward == true and recorded in result_train,
// the remaining batches are evaluated forward-only and recorded in result_eval.
// A negative idata_split means that all data is used for training.
// The optional callbacks are invoked after every batch of their respective part.
// Requires a context with static graphs.
void lm_ggml_opt_epoch(
        lm_ggml_opt_context_t      opt_ctx,
        lm_ggml_opt_dataset_t      dataset,
        lm_ggml_opt_result_t       result_train,
        lm_ggml_opt_result_t       result_eval,
        int64_t                    idata_split,
        lm_ggml_opt_epoch_callback callback_train,
        lm_ggml_opt_epoch_callback callback_eval) {
    LM_GGML_ASSERT(lm_ggml_opt_static_graphs(opt_ctx) && "lm_ggml_opt_epoch requires static graphs");
    struct lm_ggml_tensor * inputs = lm_ggml_opt_inputs(opt_ctx);
    struct lm_ggml_tensor * labels = lm_ggml_opt_labels(opt_ctx);
    struct lm_ggml_tensor * data   = lm_ggml_opt_dataset_data(dataset);
    LM_GGML_ASSERT(data->ne[0] == inputs->ne[0]);

    const int64_t ndata       = data->ne[1];
    const int64_t ndata_batch = inputs->ne[1];

    LM_GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0);
    const int64_t nbatches = ndata/ndata_batch;

    if (idata_split < 0) {
        idata_split = ndata;
    }
    LM_GGML_ASSERT(idata_split % ndata_batch == 0);
    const int64_t ibatch_split = idata_split / ndata_batch;

    int64_t ibatch = 0;

    // training part: batches [0, ibatch_split)
    int64_t t_loop_start = lm_ggml_time_us();
    while (ibatch < ibatch_split) {
        lm_ggml_opt_alloc(opt_ctx, /*backward =*/ true);
        lm_ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
        lm_ggml_opt_eval(opt_ctx, result_train);
        if (callback_train) {
            callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
        }
        ++ibatch;
    }

    // evaluation part: the batches that are left, up to nbatches
    t_loop_start = lm_ggml_time_us();
    while (ibatch < nbatches) {
        lm_ggml_opt_alloc(opt_ctx, /*backward =*/ false);
        lm_ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
        lm_ggml_opt_eval(opt_ctx, result_eval);
        if (callback_eval) {
            callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
        }
        ++ibatch;
    }
}
|
|
924
|
+
|
|
925
|
+
void lm_ggml_opt_epoch_callback_progress_bar(
|
|
926
|
+
bool train,
|
|
927
|
+
lm_ggml_opt_context_t opt_ctx,
|
|
928
|
+
lm_ggml_opt_dataset_t dataset,
|
|
929
|
+
lm_ggml_opt_result_t result,
|
|
930
|
+
int64_t ibatch,
|
|
931
|
+
int64_t ibatch_max,
|
|
932
|
+
int64_t t_start_us) {
|
|
933
|
+
fprintf(stderr, "%s[", train ? "train: " : "val: ");
|
|
934
|
+
|
|
935
|
+
// The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
|
|
936
|
+
constexpr int64_t bar_length = 8;
|
|
937
|
+
const int64_t ibatch8 = 8 * ibatch;
|
|
938
|
+
for (int64_t j = 0; j < bar_length; ++j) {
|
|
939
|
+
if (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
|
|
940
|
+
fprintf(stderr, "\u2588"); // full block
|
|
941
|
+
} else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
|
|
942
|
+
fprintf(stderr, "\u2589"); // 7/8 filled
|
|
943
|
+
} else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
|
|
944
|
+
fprintf(stderr, "\u258A"); // 6/8 filled
|
|
945
|
+
} else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
|
|
946
|
+
fprintf(stderr, "\u258B"); // 5/8 filled
|
|
947
|
+
} else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
|
|
948
|
+
fprintf(stderr, "\u258C"); // 4/8 filled
|
|
949
|
+
} else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
|
|
950
|
+
fprintf(stderr, "\u258D"); // 3/8 filled
|
|
951
|
+
} else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
|
|
952
|
+
fprintf(stderr, "\u258E"); // 2/8 filled
|
|
953
|
+
} else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
|
|
954
|
+
fprintf(stderr, "\u258F"); // 1/8 filled
|
|
955
|
+
} else {
|
|
956
|
+
fprintf(stderr, " ");
|
|
957
|
+
}
|
|
958
|
+
}
|
|
959
|
+
|
|
960
|
+
const int64_t batch_size = lm_ggml_opt_inputs(opt_ctx)->ne[1];
|
|
961
|
+
const int64_t idata = ibatch*batch_size;
|
|
962
|
+
const int64_t idata_max = ibatch_max*batch_size;
|
|
963
|
+
|
|
964
|
+
double loss;
|
|
965
|
+
double loss_unc;
|
|
966
|
+
lm_ggml_opt_result_loss(result, &loss, &loss_unc);
|
|
967
|
+
|
|
968
|
+
double accuracy;
|
|
969
|
+
double accuracy_unc;
|
|
970
|
+
lm_ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc);
|
|
971
|
+
|
|
972
|
+
const int64_t t_ibatch_us = lm_ggml_time_us() - t_start_us;
|
|
973
|
+
int64_t t_ibatch_s = t_ibatch_us / 1000000;
|
|
974
|
+
const int64_t t_ibatch_h = t_ibatch_s / 3600;
|
|
975
|
+
t_ibatch_s -= t_ibatch_h * 3600;
|
|
976
|
+
const int64_t t_ibatch_m = t_ibatch_s / 60;
|
|
977
|
+
t_ibatch_s -= t_ibatch_m * 60;
|
|
978
|
+
|
|
979
|
+
const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch;
|
|
980
|
+
int64_t t_eta_s = t_eta_us / 1000000;
|
|
981
|
+
const int64_t t_eta_h = t_eta_s / 3600;
|
|
982
|
+
t_eta_s -= t_eta_h * 3600;
|
|
983
|
+
const int64_t t_eta_m = t_eta_s / 60;
|
|
984
|
+
t_eta_s -= t_eta_m * 60;
|
|
985
|
+
|
|
986
|
+
fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
|
|
987
|
+
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
|
|
988
|
+
idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
|
|
989
|
+
t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
|
|
990
|
+
if (ibatch == ibatch_max) {
|
|
991
|
+
fprintf(stderr, "\n");
|
|
992
|
+
}
|
|
993
|
+
fflush(stderr);
|
|
994
|
+
|
|
995
|
+
LM_GGML_UNUSED(dataset);
|
|
996
|
+
}
|
|
997
|
+
|
|
998
|
+
// Convenience training loop: creates an optimization context for the given model, runs
// nepoch epochs over dataset and frees everything it created again.
//
//   nbatch_logical: datapoints per optimizer step; must divide the dataset size and be a
//                   multiple of the physical batch size inputs->ne[1].
//   val_split:      fraction in [0, 1) of the logical batches used for validation instead of training.
//   get_opt_pars:   optimizer parameter callback; its userdata is a pointer to the current epoch (int64_t).
//   silent:         suppresses the progress output on stderr.
void lm_ggml_opt_fit(
        lm_ggml_backend_sched_t          backend_sched,
        lm_ggml_context                * ctx_compute,
        lm_ggml_tensor                 * inputs,
        lm_ggml_tensor                 * outputs,
        lm_ggml_opt_dataset_t            dataset,
        enum lm_ggml_opt_loss_type       loss_type,
        enum lm_ggml_opt_optimizer_type  optimizer,
        lm_ggml_opt_get_optimizer_params get_opt_pars,
        int64_t                          nepoch,
        int64_t                          nbatch_logical,
        float                            val_split,
        bool                             silent) {
    lm_ggml_time_init();
    const int64_t t_start_us = lm_ggml_time_us();

    const int64_t ndata           = lm_ggml_opt_dataset_data(dataset)->ne[1];
    const int64_t nbatch_physical = inputs->ne[1];
    LM_GGML_ASSERT(ndata % nbatch_logical == 0);
    LM_GGML_ASSERT(nbatch_logical % nbatch_physical == 0);

    // Number of physical batches whose gradients are accumulated per optimizer step.
    const int64_t opt_period       = nbatch_logical / nbatch_physical;
    const int64_t nbatches_logical = ndata / nbatch_logical;

    LM_GGML_ASSERT(val_split >= 0.0f);
    LM_GGML_ASSERT(val_split <  1.0f);
    // train <-> val split index (physical)
    const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period;
    const int64_t idata_split  = ibatch_split * nbatch_physical;

    // The address of epoch is handed to get_opt_pars as userdata, so the callback can read the current epoch.
    int64_t epoch = 1;

    lm_ggml_opt_params params = lm_ggml_opt_default_params(backend_sched, loss_type);
    params.ctx_compute     = ctx_compute;
    params.inputs          = inputs;
    params.outputs         = outputs;
    params.opt_period      = opt_period;
    params.optimizer       = optimizer;
    params.get_opt_pars    = get_opt_pars;
    params.get_opt_pars_ud = &epoch;
    lm_ggml_opt_context_t opt_ctx = lm_ggml_opt_init(params);

    // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
    if (nbatch_logical < ndata) {
        lm_ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation).
    }

    lm_ggml_opt_result_t result_train = lm_ggml_opt_result_init();
    lm_ggml_opt_result_t result_val   = lm_ggml_opt_result_init();

    lm_ggml_opt_epoch_callback epoch_callback = nullptr;
    if (!silent) {
        epoch_callback = lm_ggml_opt_epoch_callback_progress_bar;
    }

    while (epoch <= nepoch) {
        // Per epoch only the data before idata_split is reshuffled.
        if (nbatch_logical < idata_split) {
            lm_ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split);
        }

        lm_ggml_opt_result_reset(result_train);
        lm_ggml_opt_result_reset(result_val);

        if (!silent) {
            fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch);
        }
        lm_ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback);
        if (!silent) {
            fprintf(stderr, "\n");
        }

        ++epoch;
    }

    if (!silent) {
        // total wall time, split into h:m:s
        const int64_t t_total   = (lm_ggml_time_us() - t_start_us) / 1000000;
        const int64_t t_total_h = t_total / 3600;
        const int64_t t_rest    = t_total % 3600;
        const int64_t t_total_m = t_rest / 60;
        const int64_t t_total_s = t_rest % 60;
        fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s);
    }

    lm_ggml_opt_free(opt_ctx);
    lm_ggml_opt_result_free(result_train);
    lm_ggml_opt_result_free(result_val);
}
|
|
1079
|
+
|
|
1080
|
+
// Accessor: returns the optimizer type stored in the context.
enum lm_ggml_opt_optimizer_type lm_ggml_opt_context_optimizer_type(lm_ggml_opt_context_t c) {
    const enum lm_ggml_opt_optimizer_type type = c->optimizer;
    return type;
}
|
|
1083
|
+
|
|
1084
|
+
// Maps an optimizer type to a short lowercase name ("adamw", "sgd").
// Any other value yields "undefined". The returned string is a literal and must not be freed.
LM_GGML_API const char * lm_ggml_opt_optimizer_name(enum lm_ggml_opt_optimizer_type o) {
    const char * name = "undefined";
    if (o == LM_GGML_OPT_OPTIMIZER_TYPE_ADAMW) {
        name = "adamw";
    } else if (o == LM_GGML_OPT_OPTIMIZER_TYPE_SGD) {
        name = "sgd";
    }
    return name;
}
|