@fugood/llama.node 0.3.17 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +39 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +366 -19
- package/src/LlamaCompletionWorker.h +30 -10
- package/src/LlamaContext.cpp +213 -5
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
- package/src/llama.cpp/.github/workflows/build.yml +41 -762
- package/src/llama.cpp/.github/workflows/docker.yml +5 -2
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +12 -12
- package/src/llama.cpp/CMakeLists.txt +5 -17
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +31 -3
- package/src/llama.cpp/common/arg.cpp +48 -29
- package/src/llama.cpp/common/chat.cpp +128 -106
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +37 -1
- package/src/llama.cpp/common/common.h +18 -9
- package/src/llama.cpp/common/llguidance.cpp +1 -0
- package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/src/llama.cpp/common/minja/minja.hpp +69 -36
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +57 -50
- package/src/llama.cpp/examples/CMakeLists.txt +2 -23
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
- package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml.h +10 -7
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
- package/src/llama.cpp/ggml/src/ggml.c +29 -20
- package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/src/llama.cpp/include/llama.h +52 -11
- package/src/llama.cpp/requirements/requirements-all.txt +3 -3
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-adapter.cpp +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +3 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +17 -7
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +389 -501
- package/src/llama.cpp/src/llama-context.h +44 -32
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +20 -38
- package/src/llama.cpp/src/llama-graph.h +12 -8
- package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
- package/src/llama.cpp/src/llama-kv-cache.h +271 -85
- package/src/llama.cpp/src/llama-memory.h +11 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +316 -69
- package/src/llama.cpp/src/llama-model.h +8 -1
- package/src/llama.cpp/src/llama-quant.cpp +15 -13
- package/src/llama.cpp/src/llama-sampling.cpp +18 -6
- package/src/llama.cpp/src/llama-vocab.cpp +42 -4
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +14 -0
- package/src/llama.cpp/tests/CMakeLists.txt +10 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
- package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
- package/src/llama.cpp/tests/test-chat.cpp +3 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
- package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
- package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
- package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
- package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.h +0 -135
- package/src/llama.cpp/examples/llava/llava.cpp +0 -586
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/mtmd.h +0 -168
- package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/ggml/src/ggml-opt.cpp

@@ -28,16 +28,19 @@ struct ggml_opt_dataset {
 };
 
 struct ggml_opt_context {
-    ggml_backend_sched_t
-    ggml_cgraph
-    ggml_cgraph
-    struct ggml_context
-    struct ggml_context
-    struct ggml_context
-    struct ggml_context
-    ggml_backend_buffer_t
-    ggml_backend_buffer_t
-    std::mt19937
+    ggml_backend_sched_t backend_sched = nullptr;
+    ggml_cgraph * allocated_graph = nullptr;
+    ggml_cgraph * allocated_graph_copy = nullptr;
+    struct ggml_context * ctx_static = nullptr;
+    struct ggml_context * ctx_cpu = nullptr;
+    struct ggml_context * ctx_compute = nullptr;
+    struct ggml_context * ctx_copy = nullptr;
+    ggml_backend_buffer_t buf_static = nullptr;
+    ggml_backend_buffer_t buf_cpu = nullptr;
+    std::mt19937 rng;
+    enum ggml_opt_loss_type loss_type;
+    enum ggml_opt_build_type build_type;
+    enum ggml_opt_build_type build_type_alloc;
 
     struct ggml_tensor * inputs = nullptr;
     struct ggml_tensor * outputs = nullptr;
@@ -50,6 +53,11 @@ struct ggml_opt_context {
     struct ggml_cgraph * gf = nullptr;
     struct ggml_cgraph * gb_grad = nullptr;
     struct ggml_cgraph * gb_opt = nullptr;
+    bool static_graphs = false;
+    bool eval_ready = false;
+    std::vector<struct ggml_tensor *> grad_accs;
+    std::vector<struct ggml_tensor *> grad_m;
+    std::vector<struct ggml_tensor *> grad_v;
 
     int64_t iter = 1;
     int32_t opt_period = 1;
@@ -73,7 +81,13 @@ struct ggml_opt_result {
 
 // ====== Dataset ======
 
-ggml_opt_dataset_t ggml_opt_dataset_init(
+ggml_opt_dataset_t ggml_opt_dataset_init(
+        enum ggml_type type_data,
+        enum ggml_type type_label,
+        int64_t ne_datapoint,
+        int64_t ne_label,
+        int64_t ndata,
+        int64_t ndata_shard) {
     GGML_ASSERT(ne_datapoint > 0);
     GGML_ASSERT(ne_label >= 0);
     GGML_ASSERT(ndata > 0);
@@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label,
         result->ctx = ggml_init(params);
     }
 
-    result->data = ggml_new_tensor_2d(result->ctx,
+    result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
     result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
 
     if (ne_label > 0) {
-        result->labels = ggml_new_tensor_2d(result->ctx,
+        result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
         result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
     } else {
         result->labels = nullptr;
@@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
     delete dataset;
 }
 
+int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
+    return dataset->ndata;
+}
+
 struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
     return dataset->data;
 }
@@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
     GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch));
     GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
     GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
+    GGML_ASSERT( data_batch->type == dataset->data->type);
+    GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);
 
     const size_t nb_data_batch = ggml_nbytes(data_batch);
     GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
@@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
     }
 }
 
+void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
+    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
+    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
+
+    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
+
+    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
+
+    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
+        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
+
+        const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data;
+        char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data;
+        memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
+
+        if (!labels_batch) {
+            continue;
+        }
+
+        const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels;
+        char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels;
+        memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
+    }
+}
+
 // ====== Model / Context ======
 
 struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
@@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     return result;
 }
 
+struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
+    return *((struct ggml_opt_optimizer_params *) userdata);
+}
+
 struct ggml_opt_params ggml_opt_default_params(
         ggml_backend_sched_t backend_sched,
-        struct ggml_context * ctx_compute,
-        struct ggml_tensor * inputs,
-        struct ggml_tensor * outputs,
         enum ggml_opt_loss_type loss_type) {
     return {
         /*backend_sched =*/ backend_sched,
-        /*ctx_compute =*/
-        /*inputs =*/
-        /*logits =*/
+        /*ctx_compute =*/ nullptr,
+        /*inputs =*/ nullptr,
+        /*logits =*/ nullptr,
         /*loss_type =*/ loss_type,
         /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT,
         /*opt_period =*/ 1,
@@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
     return dst;
 }
 
-static void
-    GGML_ASSERT(
-
-        return;
-    }
-
-    ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
-
-    {
-        ggml_init_params params = {
-            /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
-            /*.mem_buffer =*/ nullptr,
-            /*.no_alloc =*/ true,
-        };
-        ggml_free(opt_ctx->ctx_copy);
-        opt_ctx->ctx_copy = ggml_init(params);
-    }
-
-    opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
-
-    ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
-    opt_ctx->allocated_graph = graph;
-}
-
-ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
-    ggml_opt_context_t result = new struct ggml_opt_context;
-    result->backend_sched = params.backend_sched;
-    result->ctx_compute = params.ctx_compute;
-    result->inputs = params.inputs;
-    result->outputs = params.outputs;
-    result->opt_period = params.opt_period;
-    result->get_opt_pars = params.get_opt_pars;
-    result->get_opt_pars_ud = params.get_opt_pars_ud;
-
-    GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
-    GGML_ASSERT(result->opt_period >= 1);
-
-    const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
-        (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
+static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
+    GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
+    GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
 
-
-
+    const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
+        !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
 
-
-
+    ggml_set_input(opt_ctx->inputs);
+    ggml_set_output(opt_ctx->outputs);
 
     int n_param = 0;
-    for (int i = 0; i <
-
+    for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
+        const struct ggml_tensor * node = opt_ctx->gf->nodes[i];
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
            n_param++;
        }
+        GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
    }
 
-    {
+    if (!opt_ctx->ctx_static) {
        // The static context is used for:
-        //   - gradients (1 tensor per param if using gradient accumulation)
+        //   - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
        //   - optimizer momenta (2 tensors per param)
-        //   - labels
-        //   - loss
-        //   - pred
-        //   - ncorrect (2 tensors).
-
-        const size_t
+        //   - labels (if using static graphs)
+        //   - loss (if using static graphs, up to 5 tensors)
+        //   - pred (if using static graphs)
+        //   - ncorrect (if using static graphs, 2 tensors).
+        constexpr size_t n_loss = 1;
+        const size_t tensors_per_param = (accumulate ? 1 : 0) +
+            (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+        const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
+        const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
        struct ggml_init_params params = {
            /*.mem_size =*/ size_meta,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc =*/ true,
        };
-
+        opt_ctx->ctx_static = ggml_init(params);
    }
+    GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);
+
    {
-        // The
-        //
+        // The cpu context is allocated statically if using static graphs, dynamically otherwise.
+        // It is used for:
+        //   - optimizer parameters (1 shared for all optimizer invocations)
        const size_t size_meta = 1 * ggml_tensor_overhead();
        struct ggml_init_params params = {
            /*.mem_size =*/ size_meta,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc =*/ true,
        };
-
+        ggml_free(opt_ctx->ctx_cpu);
+        opt_ctx->ctx_cpu = ggml_init(params);
+
+        ggml_backend_buffer_free(opt_ctx->buf_cpu);
+        opt_ctx->buf_cpu = nullptr;
    }
 
+    struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;
 
-    switch (
+    switch (opt_ctx->loss_type) {
        case GGML_OPT_LOSS_TYPE_MEAN: {
-
-            ggml_set_name(
-            const float scale = 1.0f / (
-
-            ggml_set_name(
-
+            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
+            ggml_set_name(opt_ctx->loss, "loss_sum");
+            const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
+            opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
+            ggml_set_name(opt_ctx->loss, "loss_mean");
+            opt_ctx->loss_per_datapoint = true;
            break;
        }
        case GGML_OPT_LOSS_TYPE_SUM: {
-
-            ggml_set_name(
-
+            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
+            ggml_set_name(opt_ctx->loss, "loss_sum");
+            opt_ctx->loss_per_datapoint = false;
            break;
        }
        case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
-
-            ggml_set_input(
-            ggml_set_name(
-
-            ggml_set_name(
-            if (
-
-            ggml_set_name(
+            opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
+            ggml_set_input(opt_ctx->labels);
+            ggml_set_name(opt_ctx->labels, "labels");
+            opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
+            ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
+            if (opt_ctx->opt_period > 1) {
+                opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
+                ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
            }
-
+            opt_ctx->loss_per_datapoint = true;
            break;
        }
        case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
-
-            ggml_set_input(
-            ggml_set_name(
-
-            ggml_set_name(
-
-            ggml_set_name(
-
-            ggml_set_name(
-            const float scale = 1.0f / (
-
-            ggml_set_name(
-
+            opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
+            ggml_set_input(opt_ctx->labels);
+            ggml_set_name(opt_ctx->labels, "labels");
+            opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
+            ggml_set_name(opt_ctx->loss, "loss_error");
+            opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);
+            ggml_set_name(opt_ctx->loss, "loss_squared_error");
+            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);
+            ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
+            const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
+            opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
+            ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
+            opt_ctx->loss_per_datapoint = true;
            break;
        }
    }
-    ggml_set_output(
-    ggml_set_loss(
-    ggml_build_forward_expand(
-
-
-
-
-
+    ggml_set_output(opt_ctx->loss);
+    ggml_set_loss(opt_ctx->loss);
+    ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);
+
+    if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
+        opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);
+        ggml_set_name(opt_ctx->pred, "pred");
+        ggml_set_output(opt_ctx->pred);
+        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);
+
+        opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));
+        ggml_set_name(opt_ctx->ncorrect, "ncorrect");
+        ggml_set_output(opt_ctx->ncorrect);
+        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
+    }
 
-    if (
-
-
-
-
-
-
+    if (opt_ctx->buf_static) {
+        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
+            return;
+        }
+    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {
+        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
+            opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
+        return;
    }
 
-    if (
-
-
+    if (opt_ctx->grad_accs.empty()) {
+        GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);
+
+        const int n_nodes = opt_ctx->gf->n_nodes;
+        opt_ctx->grad_accs.resize(n_nodes);
+        for (int i = 0; i < n_nodes; ++i) {
+            ggml_tensor * node = opt_ctx->gf->nodes[i];
+            if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
+                opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+            } else {
+                opt_ctx->grad_accs[i] = nullptr;
+            }
+        }
+
+        if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
+            opt_ctx->grad_m.resize(n_nodes);
+            opt_ctx->grad_v.resize(n_nodes);
+            for (int i = 0; i < n_nodes; ++i) {
+                ggml_tensor * node = opt_ctx->gf->nodes[i];
+                if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+                    opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+                    opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+                } else {
+                    opt_ctx->grad_m[i] = nullptr;
+                    opt_ctx->grad_v[i] = nullptr;
+                }
+            }
+        }
    }
 
     // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
-
-    ggml_build_backward_expand(
+    opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
+    ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
 
-    if (
-
-
-
+    if (opt_ctx->buf_static) {
+        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
+            return;
+        }
+    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
+        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
+        ggml_graph_reset(opt_ctx->gb_grad);
    }
 
-    GGML_ASSERT(
+    GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);
 
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
-
+    opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
 
-
-    ggml_set_input(
-    ggml_set_name(
+    opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
+    ggml_set_input(opt_ctx->adamw_params);
+    ggml_set_name(opt_ctx->adamw_params, "adamw_params");
 
-    for (int i =
-        struct ggml_tensor * node =
-        struct ggml_tensor * grad = ggml_graph_get_grad(
+    for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
+        struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
+        struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
 
-        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
-            struct ggml_tensor * m =
-            struct ggml_tensor * v =
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(
-
+        if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+            struct ggml_tensor * m = opt_ctx->grad_m[i];
+            struct ggml_tensor * v = opt_ctx->grad_v[i];
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
+
+            ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
+            ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
+            ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
+
+            ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
        }
    }
 
-
-
+    if (!opt_ctx->buf_static) {
+        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
+            opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
+        ggml_graph_reset(opt_ctx->gb_opt);
+    }
 
-
+    opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
+}
 
-
+ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
+    ggml_opt_context_t result = new struct ggml_opt_context;
+    result->backend_sched = params.backend_sched;
+    result->ctx_compute = params.ctx_compute;
+    result->loss_type = params.loss_type;
+    result->build_type = params.build_type;
+    result->build_type_alloc = params.build_type;
+    result->inputs = params.inputs;
+    result->outputs = params.outputs;
+    result->opt_period = params.opt_period;
+    result->get_opt_pars = params.get_opt_pars;
+    result->get_opt_pars_ud = params.get_opt_pars_ud;
+
+    GGML_ASSERT(result->opt_period >= 1);
+
+    result->static_graphs = result->ctx_compute;
+
+    if (!result->static_graphs) {
+        GGML_ASSERT(!result->inputs);
+        GGML_ASSERT(!result->outputs);
+        return result;
+    }
+
+    GGML_ASSERT(result->inputs);
+    GGML_ASSERT(result->outputs);
+
+    result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
+    ggml_build_forward_expand(result->gf, result->outputs);
+
+    ggml_opt_build(result);
 
     return result;
 }
@@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
         return;
     }
     ggml_backend_buffer_free(opt_ctx->buf_static);
-    ggml_backend_buffer_free(opt_ctx->
+    ggml_backend_buffer_free(opt_ctx->buf_cpu);
     ggml_free(opt_ctx->ctx_static);
-    ggml_free(opt_ctx->
+    ggml_free(opt_ctx->ctx_cpu);
     delete opt_ctx;
 }
 
@@ -582,8 +679,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl
 
 // ====== Computation ======
 
-
-
+void ggml_opt_prepare_alloc(
+        ggml_opt_context_t opt_ctx,
+        struct ggml_context * ctx_compute,
+        struct ggml_cgraph * gf,
+        struct ggml_tensor * inputs,
+        struct ggml_tensor * outputs) {
+    GGML_ASSERT(!opt_ctx->static_graphs);
+    opt_ctx->ctx_compute = ctx_compute;
+    opt_ctx->gf = gf;
+    opt_ctx->inputs = inputs;
+    opt_ctx->outputs = outputs;
+}
+
+void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
+    GGML_ASSERT(!opt_ctx->eval_ready);
+    if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
+        ggml_graph_reset(opt_ctx->gb_grad);
+    }
+    if (backward) {
+        const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
+        opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
+    } else {
+        opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
+    }
+
+    if (!opt_ctx->static_graphs) {
+        ggml_opt_build(opt_ctx);
+    }
+
+    struct ggml_cgraph * graph = nullptr;
+    switch (opt_ctx->build_type) {
+        case GGML_OPT_BUILD_TYPE_FORWARD: {
+            graph = opt_ctx->gf;
+        } break;
+        case GGML_OPT_BUILD_TYPE_GRAD: {
+            graph = opt_ctx->gb_grad;
+        } break;
+        case GGML_OPT_BUILD_TYPE_OPT: {
+            graph = opt_ctx->gb_opt;
+        } break;
+    }
+    GGML_ASSERT(graph);
+
+    if (opt_ctx->allocated_graph == graph) {
+        opt_ctx->eval_ready = true;
+        return;
+    }
+
+    ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
+
+    if (opt_ctx->static_graphs) {
+        ggml_init_params params = {
+            /*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc =*/ true,
+        };
+        ggml_free(opt_ctx->ctx_copy);
+        opt_ctx->ctx_copy = ggml_init(params);
+
+        opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
+    } else {
+        opt_ctx->allocated_graph_copy = graph;
+    }
+
+    ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
+    opt_ctx->allocated_graph = graph;
+
+    opt_ctx->eval_ready = true;
+}
+
+void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
+    GGML_ASSERT(opt_ctx->eval_ready);
+    if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
        struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
 
        GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
@@ -609,9 +777,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
        adamw_par_data[6] = beta2h;
    }
 
-    ggml_opt_alloc_graph(opt_ctx, graph);
     ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
     opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
+    opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
+
+    if (!opt_ctx->static_graphs) {
+        opt_ctx->gf = nullptr;
+        opt_ctx->gb_grad = nullptr;
+        opt_ctx->gb_opt = nullptr;
+        opt_ctx->allocated_graph = nullptr;
+        opt_ctx->allocated_graph_copy = nullptr;
+    }
+
+    opt_ctx->eval_ready = false;
 
     if (!result) {
         return;
@@ -635,12 +813,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
     ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
     result->loss.push_back(loss);
 
-
-
-
-
+    if (opt_ctx->pred) {
+        GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
+        std::vector<int32_t> pred(ndata);
+        ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
+        result->pred.insert(result->pred.end(), pred.begin(), pred.end());
+    }
 
-    if (!opt_ctx->
+    if (!opt_ctx->ncorrect || result->ncorrect < 0) {
         result->ncorrect = -1;
         return;
     }
@@ -652,26 +832,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
     result->ncorrect += ncorrect;
 }
 
-void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
-    ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
-}
-
-void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
-    if (opt_ctx->opt_period == 1) {
-        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
-        return;
-    }
-
-    const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
-    if (opt_i_next == 0) {
-        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
-        ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
-    } else {
-        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
-    }
-    opt_ctx->opt_i = opt_i_next;
-}
-
 // ====== High-Level Functions ======
 
 void ggml_opt_epoch(
@@ -700,16 +860,18 @@ void ggml_opt_epoch(
     int64_t ibatch = 0;
     int64_t t_loop_start = ggml_time_us();
     for (; ibatch < ibatch_split; ++ibatch) {
+        ggml_opt_alloc(opt_ctx, /*backward =*/ true);
         ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
-
+        ggml_opt_eval(opt_ctx, result_train);
         if (callback_train) {
             callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
         }
     }
     t_loop_start = ggml_time_us();
     for (; ibatch < nbatches; ++ibatch) {
+        ggml_opt_alloc(opt_ctx, /*backward =*/ false);
         ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
-
+        ggml_opt_eval(opt_ctx, result_eval);
         if (callback_eval) {
             callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
         }
@@ -726,13 +888,26 @@ void ggml_opt_epoch_callback_progress_bar(
         int64_t t_start_us) {
     fprintf(stderr, "%s[", train ? "train: " : "val: ");
 
-
+    // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
+    constexpr int64_t bar_length = 8;
+    const int64_t ibatch8 = 8 * ibatch;
     for (int64_t j = 0; j < bar_length; ++j) {
-
-
-
-
-
+        if (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
+            fprintf(stderr, "\u2588"); // full block
+        } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
+            fprintf(stderr, "\u2589"); // 7/8 filled
+        } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258A"); // 6/8 filled
+        } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258B"); // 5/8 filled
+        } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258C"); // 4/8 filled
+        } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258D"); // 3/8 filled
+        } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258E"); // 2/8 filled
+        } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258F"); // 1/8 filled
         } else {
             fprintf(stderr, " ");
         }
@@ -764,8 +939,8 @@ void ggml_opt_epoch_callback_progress_bar(
     const int64_t t_eta_m = t_eta_s / 60;
     t_eta_s -= t_eta_m * 60;
 
-    fprintf(stderr, "
-        "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "
+    fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
+        "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
         idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
         t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
     if (ibatch == ibatch_max) {
@@ -806,7 +981,10 @@ void ggml_opt_fit(
 
     int64_t epoch = 1;
 
-    ggml_opt_params params = ggml_opt_default_params(backend_sched,
+    ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
+    params.ctx_compute = ctx_compute;
+    params.inputs = inputs;
+    params.outputs = outputs;
     params.opt_period = opt_period;
     params.get_opt_pars = get_opt_pars;
     params.get_opt_pars_ud = &epoch;