cui-llama.rn 1.6.1 → 1.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +6 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +51 -14
- package/android/src/main/java/com/rnllama/RNLlama.java +158 -6
- package/android/src/main/jni.cpp +153 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/chat.cpp +128 -106
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +38 -76
- package/cpp/common.h +23 -19
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/cpp/ggml-cpu/ggml-cpu.c +5 -13
- package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
- package/cpp/ggml-cpu/ops.cpp +107 -13
- package/cpp/ggml-cpu/vec.cpp +0 -6
- package/cpp/ggml-cpu/vec.h +16 -0
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +321 -132
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +93 -38
- package/cpp/ggml.h +21 -7
- package/cpp/gguf.cpp +33 -33
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +3 -0
- package/cpp/llama-batch.cpp +3 -1
- package/cpp/llama-chat.cpp +8 -6
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +349 -135
- package/cpp/llama-context.h +30 -3
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +150 -234
- package/cpp/llama-graph.h +52 -7
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +34 -5
- package/cpp/llama-kv-cache.cpp +662 -321
- package/cpp/llama-kv-cache.h +203 -93
- package/cpp/llama-memory.h +3 -2
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +536 -132
- package/cpp/llama-model.h +7 -1
- package/cpp/llama-sampling.cpp +18 -6
- package/cpp/llama-vocab.cpp +46 -8
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +72 -131
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +611 -47
- package/cpp/rn-llama.h +33 -3
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +7 -0
- package/ios/RNLlama.mm +77 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +105 -10
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +50 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +67 -4
- package/src/index.ts +212 -38
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
package/cpp/ggml-opt.cpp
CHANGED
@@ -28,16 +28,19 @@ struct lm_ggml_opt_dataset {
|
|
28
28
|
};
|
29
29
|
|
30
30
|
struct lm_ggml_opt_context {
|
31
|
-
lm_ggml_backend_sched_t
|
32
|
-
lm_ggml_cgraph
|
33
|
-
lm_ggml_cgraph
|
34
|
-
struct lm_ggml_context
|
35
|
-
struct lm_ggml_context
|
36
|
-
struct lm_ggml_context
|
37
|
-
struct lm_ggml_context
|
38
|
-
lm_ggml_backend_buffer_t
|
39
|
-
lm_ggml_backend_buffer_t
|
40
|
-
std::mt19937
|
31
|
+
lm_ggml_backend_sched_t backend_sched = nullptr;
|
32
|
+
lm_ggml_cgraph * allocated_graph = nullptr;
|
33
|
+
lm_ggml_cgraph * allocated_graph_copy = nullptr;
|
34
|
+
struct lm_ggml_context * ctx_static = nullptr;
|
35
|
+
struct lm_ggml_context * ctx_cpu = nullptr;
|
36
|
+
struct lm_ggml_context * ctx_compute = nullptr;
|
37
|
+
struct lm_ggml_context * ctx_copy = nullptr;
|
38
|
+
lm_ggml_backend_buffer_t buf_static = nullptr;
|
39
|
+
lm_ggml_backend_buffer_t buf_cpu = nullptr;
|
40
|
+
std::mt19937 rng;
|
41
|
+
enum lm_ggml_opt_loss_type loss_type;
|
42
|
+
enum lm_ggml_opt_build_type build_type;
|
43
|
+
enum lm_ggml_opt_build_type build_type_alloc;
|
41
44
|
|
42
45
|
struct lm_ggml_tensor * inputs = nullptr;
|
43
46
|
struct lm_ggml_tensor * outputs = nullptr;
|
@@ -50,6 +53,11 @@ struct lm_ggml_opt_context {
|
|
50
53
|
struct lm_ggml_cgraph * gf = nullptr;
|
51
54
|
struct lm_ggml_cgraph * gb_grad = nullptr;
|
52
55
|
struct lm_ggml_cgraph * gb_opt = nullptr;
|
56
|
+
bool static_graphs = false;
|
57
|
+
bool eval_ready = false;
|
58
|
+
std::vector<struct lm_ggml_tensor *> grad_accs;
|
59
|
+
std::vector<struct lm_ggml_tensor *> grad_m;
|
60
|
+
std::vector<struct lm_ggml_tensor *> grad_v;
|
53
61
|
|
54
62
|
int64_t iter = 1;
|
55
63
|
int32_t opt_period = 1;
|
@@ -73,7 +81,13 @@ struct lm_ggml_opt_result {
|
|
73
81
|
|
74
82
|
// ====== Dataset ======
|
75
83
|
|
76
|
-
lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init(
|
84
|
+
lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init(
|
85
|
+
enum lm_ggml_type type_data,
|
86
|
+
enum lm_ggml_type type_label,
|
87
|
+
int64_t ne_datapoint,
|
88
|
+
int64_t ne_label,
|
89
|
+
int64_t ndata,
|
90
|
+
int64_t ndata_shard) {
|
77
91
|
LM_GGML_ASSERT(ne_datapoint > 0);
|
78
92
|
LM_GGML_ASSERT(ne_label >= 0);
|
79
93
|
LM_GGML_ASSERT(ndata > 0);
|
@@ -92,11 +106,11 @@ lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_
|
|
92
106
|
result->ctx = lm_ggml_init(params);
|
93
107
|
}
|
94
108
|
|
95
|
-
result->data = lm_ggml_new_tensor_2d(result->ctx,
|
109
|
+
result->data = lm_ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
|
96
110
|
result->nbs_data = lm_ggml_nbytes(result->data) * ndata_shard/ndata;
|
97
111
|
|
98
112
|
if (ne_label > 0) {
|
99
|
-
result->labels = lm_ggml_new_tensor_2d(result->ctx,
|
113
|
+
result->labels = lm_ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
|
100
114
|
result->nbs_labels = lm_ggml_nbytes(result->labels) * ndata_shard/ndata;
|
101
115
|
} else {
|
102
116
|
result->labels = nullptr;
|
@@ -119,6 +133,10 @@ void lm_ggml_opt_dataset_free(lm_ggml_opt_dataset_t dataset) {
|
|
119
133
|
delete dataset;
|
120
134
|
}
|
121
135
|
|
136
|
+
int64_t lm_ggml_opt_dataset_ndata(lm_ggml_opt_dataset_t dataset) {
|
137
|
+
return dataset->ndata;
|
138
|
+
}
|
139
|
+
|
122
140
|
struct lm_ggml_tensor * lm_ggml_opt_dataset_data(lm_ggml_opt_dataset_t dataset) {
|
123
141
|
return dataset->data;
|
124
142
|
}
|
@@ -144,6 +162,8 @@ void lm_ggml_opt_dataset_get_batch(lm_ggml_opt_dataset_t dataset, struct lm_ggml
|
|
144
162
|
LM_GGML_ASSERT( data_batch && lm_ggml_is_contiguous(data_batch));
|
145
163
|
LM_GGML_ASSERT(!labels_batch || lm_ggml_is_contiguous(labels_batch));
|
146
164
|
LM_GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
|
165
|
+
LM_GGML_ASSERT( data_batch->type == dataset->data->type);
|
166
|
+
LM_GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);
|
147
167
|
|
148
168
|
const size_t nb_data_batch = lm_ggml_nbytes(data_batch);
|
149
169
|
LM_GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
|
@@ -171,6 +191,31 @@ void lm_ggml_opt_dataset_get_batch(lm_ggml_opt_dataset_t dataset, struct lm_ggml
|
|
171
191
|
}
|
172
192
|
}
|
173
193
|
|
194
|
+
void lm_ggml_opt_dataset_get_batch_host(lm_ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
|
195
|
+
LM_GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
|
196
|
+
LM_GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
|
197
|
+
|
198
|
+
const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
|
199
|
+
|
200
|
+
LM_GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
|
201
|
+
|
202
|
+
for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
|
203
|
+
const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
|
204
|
+
|
205
|
+
const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data;
|
206
|
+
char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data;
|
207
|
+
memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
|
208
|
+
|
209
|
+
if (!labels_batch) {
|
210
|
+
continue;
|
211
|
+
}
|
212
|
+
|
213
|
+
const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels;
|
214
|
+
char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels;
|
215
|
+
memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
|
216
|
+
}
|
217
|
+
}
|
218
|
+
|
174
219
|
// ====== Model / Context ======
|
175
220
|
|
176
221
|
struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_default_optimizer_params(void * userdata) {
|
@@ -187,17 +232,18 @@ struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_default_optimizer_params(voi
|
|
187
232
|
return result;
|
188
233
|
}
|
189
234
|
|
235
|
+
struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_constant_optimizer_params(void * userdata) {
|
236
|
+
return *((struct lm_ggml_opt_optimizer_params *) userdata);
|
237
|
+
}
|
238
|
+
|
190
239
|
struct lm_ggml_opt_params lm_ggml_opt_default_params(
|
191
240
|
lm_ggml_backend_sched_t backend_sched,
|
192
|
-
struct lm_ggml_context * ctx_compute,
|
193
|
-
struct lm_ggml_tensor * inputs,
|
194
|
-
struct lm_ggml_tensor * outputs,
|
195
241
|
enum lm_ggml_opt_loss_type loss_type) {
|
196
242
|
return {
|
197
243
|
/*backend_sched =*/ backend_sched,
|
198
|
-
/*ctx_compute =*/
|
199
|
-
/*inputs =*/
|
200
|
-
/*logits =*/
|
244
|
+
/*ctx_compute =*/ nullptr,
|
245
|
+
/*inputs =*/ nullptr,
|
246
|
+
/*logits =*/ nullptr,
|
201
247
|
/*loss_type =*/ loss_type,
|
202
248
|
/*build_type =*/ LM_GGML_OPT_BUILD_TYPE_OPT,
|
203
249
|
/*opt_period =*/ 1,
|
@@ -266,195 +312,246 @@ static lm_ggml_cgraph * dup_graph(lm_ggml_context * ctx, lm_ggml_cgraph * src) {
|
|
266
312
|
return dst;
|
267
313
|
}
|
268
314
|
|
269
|
-
static void
|
270
|
-
LM_GGML_ASSERT(
|
271
|
-
|
272
|
-
return;
|
273
|
-
}
|
315
|
+
static void lm_ggml_opt_build(lm_ggml_opt_context_t opt_ctx) {
|
316
|
+
LM_GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with lm_ggml_opt_prepare_alloc");
|
317
|
+
LM_GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
|
274
318
|
|
275
|
-
|
319
|
+
const bool accumulate = opt_ctx->build_type_alloc >= LM_GGML_OPT_BUILD_TYPE_GRAD &&
|
320
|
+
!(opt_ctx->static_graphs && opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
|
276
321
|
|
277
|
-
|
278
|
-
|
279
|
-
/*.mem_size =*/ lm_ggml_tensor_overhead() * LM_GGML_DEFAULT_GRAPH_SIZE,
|
280
|
-
/*.mem_buffer =*/ nullptr,
|
281
|
-
/*.no_alloc =*/ true,
|
282
|
-
};
|
283
|
-
lm_ggml_free(opt_ctx->ctx_copy);
|
284
|
-
opt_ctx->ctx_copy = lm_ggml_init(params);
|
285
|
-
}
|
286
|
-
|
287
|
-
opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
|
288
|
-
|
289
|
-
lm_ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
290
|
-
opt_ctx->allocated_graph = graph;
|
291
|
-
}
|
292
|
-
|
293
|
-
lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) {
|
294
|
-
lm_ggml_opt_context_t result = new struct lm_ggml_opt_context;
|
295
|
-
result->backend_sched = params.backend_sched;
|
296
|
-
result->ctx_compute = params.ctx_compute;
|
297
|
-
result->inputs = params.inputs;
|
298
|
-
result->outputs = params.outputs;
|
299
|
-
result->opt_period = params.opt_period;
|
300
|
-
result->get_opt_pars = params.get_opt_pars;
|
301
|
-
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
302
|
-
|
303
|
-
LM_GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
|
304
|
-
LM_GGML_ASSERT(result->opt_period >= 1);
|
305
|
-
|
306
|
-
const bool accumulate = params.build_type == LM_GGML_OPT_BUILD_TYPE_GRAD ||
|
307
|
-
(params.build_type == LM_GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
|
308
|
-
|
309
|
-
lm_ggml_set_input(result->inputs);
|
310
|
-
lm_ggml_set_output(result->outputs);
|
311
|
-
|
312
|
-
result->gf = lm_ggml_new_graph_custom(result->ctx_compute, LM_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
|
313
|
-
lm_ggml_build_forward_expand(result->gf, result->outputs);
|
322
|
+
lm_ggml_set_input(opt_ctx->inputs);
|
323
|
+
lm_ggml_set_output(opt_ctx->outputs);
|
314
324
|
|
315
325
|
int n_param = 0;
|
316
|
-
for (int i = 0; i <
|
317
|
-
|
326
|
+
for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
|
327
|
+
const struct lm_ggml_tensor * node = opt_ctx->gf->nodes[i];
|
328
|
+
if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
|
318
329
|
n_param++;
|
319
330
|
}
|
331
|
+
LM_GGML_ASSERT(!(node->flags & LM_GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
|
320
332
|
}
|
321
333
|
|
322
|
-
{
|
334
|
+
if (!opt_ctx->ctx_static) {
|
323
335
|
// The static context is used for:
|
324
|
-
// - gradients (1 tensor per param if using gradient accumulation)
|
336
|
+
// - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
|
325
337
|
// - optimizer momenta (2 tensors per param)
|
326
|
-
// - labels
|
327
|
-
// - loss
|
328
|
-
// - pred
|
329
|
-
// - ncorrect (2 tensors).
|
330
|
-
|
331
|
-
const size_t
|
338
|
+
// - labels (if using static graphs)
|
339
|
+
// - loss (if using static graphs, up to 5 tensors)
|
340
|
+
// - pred (if using static graphs)
|
341
|
+
// - ncorrect (if using static graphs, 2 tensors).
|
342
|
+
constexpr size_t n_loss = 1;
|
343
|
+
const size_t tensors_per_param = (accumulate ? 1 : 0) +
|
344
|
+
(opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
|
345
|
+
const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
|
346
|
+
const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * lm_ggml_tensor_overhead();
|
332
347
|
struct lm_ggml_init_params params = {
|
333
348
|
/*.mem_size =*/ size_meta,
|
334
349
|
/*.mem_buffer =*/ nullptr,
|
335
350
|
/*.no_alloc =*/ true,
|
336
351
|
};
|
337
|
-
|
352
|
+
opt_ctx->ctx_static = lm_ggml_init(params);
|
338
353
|
}
|
354
|
+
LM_GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);
|
355
|
+
|
339
356
|
{
|
340
|
-
// The
|
341
|
-
//
|
357
|
+
// The cpu context is allocated statically if using static graphs, dynamically otherwise.
|
358
|
+
// It is used for:
|
359
|
+
// - optimizer parameters (1 shared for all optimizer invocations)
|
342
360
|
const size_t size_meta = 1 * lm_ggml_tensor_overhead();
|
343
361
|
struct lm_ggml_init_params params = {
|
344
362
|
/*.mem_size =*/ size_meta,
|
345
363
|
/*.mem_buffer =*/ nullptr,
|
346
364
|
/*.no_alloc =*/ true,
|
347
365
|
};
|
348
|
-
|
366
|
+
lm_ggml_free(opt_ctx->ctx_cpu);
|
367
|
+
opt_ctx->ctx_cpu = lm_ggml_init(params);
|
368
|
+
|
369
|
+
lm_ggml_backend_buffer_free(opt_ctx->buf_cpu);
|
370
|
+
opt_ctx->buf_cpu = nullptr;
|
349
371
|
}
|
350
372
|
|
373
|
+
struct lm_ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;
|
351
374
|
|
352
|
-
switch (
|
375
|
+
switch (opt_ctx->loss_type) {
|
353
376
|
case LM_GGML_OPT_LOSS_TYPE_MEAN: {
|
354
|
-
|
355
|
-
lm_ggml_set_name(
|
356
|
-
const float scale = 1.0f / (
|
357
|
-
|
358
|
-
lm_ggml_set_name(
|
359
|
-
|
377
|
+
opt_ctx->loss = lm_ggml_sum(ctx_results, opt_ctx->outputs);
|
378
|
+
lm_ggml_set_name(opt_ctx->loss, "loss_sum");
|
379
|
+
const float scale = 1.0f / (opt_ctx->opt_period * lm_ggml_nelements(opt_ctx->outputs));
|
380
|
+
opt_ctx->loss = lm_ggml_scale(ctx_results, opt_ctx->loss, scale);
|
381
|
+
lm_ggml_set_name(opt_ctx->loss, "loss_mean");
|
382
|
+
opt_ctx->loss_per_datapoint = true;
|
360
383
|
break;
|
361
384
|
}
|
362
385
|
case LM_GGML_OPT_LOSS_TYPE_SUM: {
|
363
|
-
|
364
|
-
lm_ggml_set_name(
|
365
|
-
|
386
|
+
opt_ctx->loss = lm_ggml_sum(ctx_results, opt_ctx->outputs);
|
387
|
+
lm_ggml_set_name(opt_ctx->loss, "loss_sum");
|
388
|
+
opt_ctx->loss_per_datapoint = false;
|
366
389
|
break;
|
367
390
|
}
|
368
391
|
case LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
|
369
|
-
|
370
|
-
lm_ggml_set_input(
|
371
|
-
lm_ggml_set_name(
|
372
|
-
|
373
|
-
lm_ggml_set_name(
|
374
|
-
if (
|
375
|
-
|
376
|
-
lm_ggml_set_name(
|
392
|
+
opt_ctx->labels = lm_ggml_dup_tensor(ctx_results, opt_ctx->outputs);
|
393
|
+
lm_ggml_set_input(opt_ctx->labels);
|
394
|
+
lm_ggml_set_name(opt_ctx->labels, "labels");
|
395
|
+
opt_ctx->loss = lm_ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
|
396
|
+
lm_ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
|
397
|
+
if (opt_ctx->opt_period > 1) {
|
398
|
+
opt_ctx->loss = lm_ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
|
399
|
+
lm_ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
|
377
400
|
}
|
378
|
-
|
401
|
+
opt_ctx->loss_per_datapoint = true;
|
379
402
|
break;
|
380
403
|
}
|
381
404
|
case LM_GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
|
382
|
-
|
383
|
-
lm_ggml_set_input(
|
384
|
-
lm_ggml_set_name(
|
385
|
-
|
386
|
-
lm_ggml_set_name(
|
387
|
-
|
388
|
-
lm_ggml_set_name(
|
389
|
-
|
390
|
-
lm_ggml_set_name(
|
391
|
-
const float scale = 1.0f / (
|
392
|
-
|
393
|
-
lm_ggml_set_name(
|
394
|
-
|
405
|
+
opt_ctx->labels = lm_ggml_dup_tensor(ctx_results, opt_ctx->outputs);
|
406
|
+
lm_ggml_set_input(opt_ctx->labels);
|
407
|
+
lm_ggml_set_name(opt_ctx->labels, "labels");
|
408
|
+
opt_ctx->loss = lm_ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
|
409
|
+
lm_ggml_set_name(opt_ctx->loss, "loss_error");
|
410
|
+
opt_ctx->loss = lm_ggml_sqr(ctx_results, opt_ctx->loss);
|
411
|
+
lm_ggml_set_name(opt_ctx->loss, "loss_squared_error");
|
412
|
+
opt_ctx->loss = lm_ggml_sum(ctx_results, opt_ctx->loss);
|
413
|
+
lm_ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
|
414
|
+
const float scale = 1.0f / (opt_ctx->opt_period * lm_ggml_nelements(opt_ctx->outputs));
|
415
|
+
opt_ctx->loss = lm_ggml_scale(ctx_results, opt_ctx->loss, scale);
|
416
|
+
lm_ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
|
417
|
+
opt_ctx->loss_per_datapoint = true;
|
395
418
|
break;
|
396
419
|
}
|
397
420
|
}
|
398
|
-
lm_ggml_set_output(
|
399
|
-
lm_ggml_set_loss(
|
400
|
-
lm_ggml_build_forward_expand(
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
421
|
+
lm_ggml_set_output(opt_ctx->loss);
|
422
|
+
lm_ggml_set_loss(opt_ctx->loss);
|
423
|
+
lm_ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);
|
424
|
+
|
425
|
+
if (opt_ctx->loss_type == LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
|
426
|
+
opt_ctx->pred = lm_ggml_argmax(ctx_results, opt_ctx->outputs);
|
427
|
+
lm_ggml_set_name(opt_ctx->pred, "pred");
|
428
|
+
lm_ggml_set_output(opt_ctx->pred);
|
429
|
+
lm_ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);
|
430
|
+
|
431
|
+
opt_ctx->ncorrect = lm_ggml_count_equal(ctx_results, opt_ctx->pred, lm_ggml_argmax(ctx_results, opt_ctx->labels));
|
432
|
+
lm_ggml_set_name(opt_ctx->ncorrect, "ncorrect");
|
433
|
+
lm_ggml_set_output(opt_ctx->ncorrect);
|
434
|
+
lm_ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
|
435
|
+
}
|
406
436
|
|
407
|
-
if (
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
437
|
+
if (opt_ctx->buf_static) {
|
438
|
+
if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_FORWARD) {
|
439
|
+
return;
|
440
|
+
}
|
441
|
+
} else if (opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_FORWARD) {
|
442
|
+
opt_ctx->buf_static = lm_ggml_backend_alloc_ctx_tensors(
|
443
|
+
opt_ctx->ctx_static, lm_ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
|
444
|
+
return;
|
414
445
|
}
|
415
446
|
|
416
|
-
if (
|
417
|
-
|
418
|
-
|
447
|
+
if (opt_ctx->grad_accs.empty()) {
|
448
|
+
LM_GGML_ASSERT(opt_ctx->build_type_alloc >= LM_GGML_OPT_BUILD_TYPE_GRAD);
|
449
|
+
|
450
|
+
const int n_nodes = opt_ctx->gf->n_nodes;
|
451
|
+
opt_ctx->grad_accs.resize(n_nodes);
|
452
|
+
for (int i = 0; i < n_nodes; ++i) {
|
453
|
+
lm_ggml_tensor * node = opt_ctx->gf->nodes[i];
|
454
|
+
if ((accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) || (node->flags & LM_GGML_TENSOR_FLAG_LOSS)) {
|
455
|
+
opt_ctx->grad_accs[i] = lm_ggml_new_tensor(opt_ctx->ctx_static, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
|
456
|
+
} else {
|
457
|
+
opt_ctx->grad_accs[i] = nullptr;
|
458
|
+
}
|
459
|
+
}
|
460
|
+
|
461
|
+
if (opt_ctx->build_type_alloc >= LM_GGML_OPT_BUILD_TYPE_OPT) {
|
462
|
+
opt_ctx->grad_m.resize(n_nodes);
|
463
|
+
opt_ctx->grad_v.resize(n_nodes);
|
464
|
+
for (int i = 0; i < n_nodes; ++i) {
|
465
|
+
lm_ggml_tensor * node = opt_ctx->gf->nodes[i];
|
466
|
+
if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
|
467
|
+
opt_ctx->grad_m[i] = lm_ggml_new_tensor(opt_ctx->ctx_static, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
|
468
|
+
opt_ctx->grad_v[i] = lm_ggml_new_tensor(opt_ctx->ctx_static, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
|
469
|
+
} else {
|
470
|
+
opt_ctx->grad_m[i] = nullptr;
|
471
|
+
opt_ctx->grad_v[i] = nullptr;
|
472
|
+
}
|
473
|
+
}
|
474
|
+
}
|
419
475
|
}
|
420
476
|
|
421
477
|
// gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
|
422
|
-
|
423
|
-
lm_ggml_build_backward_expand(
|
478
|
+
opt_ctx->gb_grad = lm_ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
|
479
|
+
lm_ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
|
424
480
|
|
425
|
-
if (
|
426
|
-
|
427
|
-
|
428
|
-
|
481
|
+
if (opt_ctx->buf_static) {
|
482
|
+
if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_GRAD) {
|
483
|
+
return;
|
484
|
+
}
|
485
|
+
} else if (opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_GRAD) {
|
486
|
+
opt_ctx->buf_static = lm_ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, lm_ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
|
487
|
+
lm_ggml_graph_reset(opt_ctx->gb_grad);
|
429
488
|
}
|
430
489
|
|
431
|
-
LM_GGML_ASSERT(
|
490
|
+
LM_GGML_ASSERT(opt_ctx->build_type_alloc == LM_GGML_OPT_BUILD_TYPE_OPT);
|
432
491
|
|
433
492
|
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
|
434
|
-
|
493
|
+
opt_ctx->gb_opt = lm_ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
|
435
494
|
|
436
|
-
|
437
|
-
lm_ggml_set_input(
|
438
|
-
lm_ggml_set_name(
|
495
|
+
opt_ctx->adamw_params = lm_ggml_new_tensor_1d(opt_ctx->ctx_cpu, LM_GGML_TYPE_F32, 7);
|
496
|
+
lm_ggml_set_input(opt_ctx->adamw_params);
|
497
|
+
lm_ggml_set_name(opt_ctx->adamw_params, "adamw_params");
|
439
498
|
|
440
|
-
for (int i =
|
441
|
-
struct lm_ggml_tensor * node =
|
442
|
-
struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(
|
499
|
+
for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
|
500
|
+
struct lm_ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
|
501
|
+
struct lm_ggml_tensor * grad = lm_ggml_graph_get_grad(opt_ctx->gb_opt, node);
|
443
502
|
|
444
|
-
if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
|
445
|
-
struct lm_ggml_tensor * m =
|
446
|
-
struct lm_ggml_tensor * v =
|
447
|
-
struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(
|
448
|
-
|
503
|
+
if (grad && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) {
|
504
|
+
struct lm_ggml_tensor * m = opt_ctx->grad_m[i];
|
505
|
+
struct lm_ggml_tensor * v = opt_ctx->grad_v[i];
|
506
|
+
struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
|
507
|
+
|
508
|
+
lm_ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
|
509
|
+
lm_ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
|
510
|
+
lm_ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
|
511
|
+
|
512
|
+
lm_ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
|
449
513
|
}
|
450
514
|
}
|
451
515
|
|
452
|
-
|
453
|
-
|
516
|
+
if (!opt_ctx->buf_static) {
|
517
|
+
opt_ctx->buf_static = lm_ggml_backend_alloc_ctx_tensors(
|
518
|
+
opt_ctx->ctx_static, lm_ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
|
519
|
+
lm_ggml_graph_reset(opt_ctx->gb_opt);
|
520
|
+
}
|
454
521
|
|
455
|
-
|
522
|
+
opt_ctx->buf_cpu = lm_ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, lm_ggml_backend_cpu_buffer_type());
|
523
|
+
}
|
456
524
|
|
457
|
-
|
525
|
+
lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) {
|
526
|
+
lm_ggml_opt_context_t result = new struct lm_ggml_opt_context;
|
527
|
+
result->backend_sched = params.backend_sched;
|
528
|
+
result->ctx_compute = params.ctx_compute;
|
529
|
+
result->loss_type = params.loss_type;
|
530
|
+
result->build_type = params.build_type;
|
531
|
+
result->build_type_alloc = params.build_type;
|
532
|
+
result->inputs = params.inputs;
|
533
|
+
result->outputs = params.outputs;
|
534
|
+
result->opt_period = params.opt_period;
|
535
|
+
result->get_opt_pars = params.get_opt_pars;
|
536
|
+
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
537
|
+
|
538
|
+
LM_GGML_ASSERT(result->opt_period >= 1);
|
539
|
+
|
540
|
+
result->static_graphs = result->ctx_compute;
|
541
|
+
|
542
|
+
if (!result->static_graphs) {
|
543
|
+
LM_GGML_ASSERT(!result->inputs);
|
544
|
+
LM_GGML_ASSERT(!result->outputs);
|
545
|
+
return result;
|
546
|
+
}
|
547
|
+
|
548
|
+
LM_GGML_ASSERT(result->inputs);
|
549
|
+
LM_GGML_ASSERT(result->outputs);
|
550
|
+
|
551
|
+
result->gf = lm_ggml_new_graph_custom(result->ctx_compute, LM_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
|
552
|
+
lm_ggml_build_forward_expand(result->gf, result->outputs);
|
553
|
+
|
554
|
+
lm_ggml_opt_build(result);
|
458
555
|
|
459
556
|
return result;
|
460
557
|
}
|
@@ -464,9 +561,9 @@ void lm_ggml_opt_free(lm_ggml_opt_context_t opt_ctx) {
|
|
464
561
|
return;
|
465
562
|
}
|
466
563
|
lm_ggml_backend_buffer_free(opt_ctx->buf_static);
|
467
|
-
lm_ggml_backend_buffer_free(opt_ctx->
|
564
|
+
lm_ggml_backend_buffer_free(opt_ctx->buf_cpu);
|
468
565
|
lm_ggml_free(opt_ctx->ctx_static);
|
469
|
-
lm_ggml_free(opt_ctx->
|
566
|
+
lm_ggml_free(opt_ctx->ctx_cpu);
|
470
567
|
delete opt_ctx;
|
471
568
|
}
|
472
569
|
|
@@ -479,6 +576,10 @@ void lm_ggml_opt_reset(lm_ggml_opt_context_t opt_ctx, bool optimizer) {
|
|
479
576
|
}
|
480
577
|
}
|
481
578
|
|
579
|
+
bool lm_ggml_opt_static_graphs(lm_ggml_opt_context_t opt_ctx) {
|
580
|
+
return opt_ctx->static_graphs;
|
581
|
+
}
|
582
|
+
|
482
583
|
struct lm_ggml_tensor * lm_ggml_opt_inputs(lm_ggml_opt_context_t opt_ctx) {
|
483
584
|
return opt_ctx->inputs;
|
484
585
|
}
|
@@ -582,8 +683,79 @@ void lm_ggml_opt_result_accuracy(lm_ggml_opt_result_t result, double * accuracy,
|
|
582
683
|
|
583
684
|
// ====== Computation ======
|
584
685
|
|
585
|
-
|
586
|
-
|
686
|
+
void lm_ggml_opt_prepare_alloc(
|
687
|
+
lm_ggml_opt_context_t opt_ctx,
|
688
|
+
struct lm_ggml_context * ctx_compute,
|
689
|
+
struct lm_ggml_cgraph * gf,
|
690
|
+
struct lm_ggml_tensor * inputs,
|
691
|
+
struct lm_ggml_tensor * outputs) {
|
692
|
+
LM_GGML_ASSERT(!opt_ctx->static_graphs);
|
693
|
+
opt_ctx->ctx_compute = ctx_compute;
|
694
|
+
opt_ctx->gf = gf;
|
695
|
+
opt_ctx->inputs = inputs;
|
696
|
+
opt_ctx->outputs = outputs;
|
697
|
+
}
|
698
|
+
|
699
|
+
void lm_ggml_opt_alloc(lm_ggml_opt_context_t opt_ctx, bool backward) {
|
700
|
+
LM_GGML_ASSERT(!opt_ctx->eval_ready);
|
701
|
+
if (opt_ctx->build_type == LM_GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
|
702
|
+
lm_ggml_graph_reset(opt_ctx->gb_grad);
|
703
|
+
}
|
704
|
+
if (backward) {
|
705
|
+
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
|
706
|
+
opt_ctx->build_type = opt_i_next == 0 ? LM_GGML_OPT_BUILD_TYPE_OPT : LM_GGML_OPT_BUILD_TYPE_GRAD;
|
707
|
+
} else {
|
708
|
+
opt_ctx->build_type = LM_GGML_OPT_BUILD_TYPE_FORWARD;
|
709
|
+
}
|
710
|
+
|
711
|
+
if (!opt_ctx->static_graphs) {
|
712
|
+
lm_ggml_opt_build(opt_ctx);
|
713
|
+
}
|
714
|
+
|
715
|
+
struct lm_ggml_cgraph * graph = nullptr;
|
716
|
+
switch (opt_ctx->build_type) {
|
717
|
+
case LM_GGML_OPT_BUILD_TYPE_FORWARD: {
|
718
|
+
graph = opt_ctx->gf;
|
719
|
+
} break;
|
720
|
+
case LM_GGML_OPT_BUILD_TYPE_GRAD: {
|
721
|
+
graph = opt_ctx->gb_grad;
|
722
|
+
} break;
|
723
|
+
case LM_GGML_OPT_BUILD_TYPE_OPT: {
|
724
|
+
graph = opt_ctx->gb_opt;
|
725
|
+
} break;
|
726
|
+
}
|
727
|
+
LM_GGML_ASSERT(graph);
|
728
|
+
|
729
|
+
if (opt_ctx->allocated_graph == graph) {
|
730
|
+
opt_ctx->eval_ready = true;
|
731
|
+
return;
|
732
|
+
}
|
733
|
+
|
734
|
+
lm_ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
|
735
|
+
|
736
|
+
if (opt_ctx->static_graphs) {
|
737
|
+
lm_ggml_init_params params = {
|
738
|
+
/*.mem_size =*/ graph->size*lm_ggml_tensor_overhead() + lm_ggml_graph_overhead_custom(graph->size, graph->grads),
|
739
|
+
/*.mem_buffer =*/ nullptr,
|
740
|
+
/*.no_alloc =*/ true,
|
741
|
+
};
|
742
|
+
lm_ggml_free(opt_ctx->ctx_copy);
|
743
|
+
opt_ctx->ctx_copy = lm_ggml_init(params);
|
744
|
+
|
745
|
+
opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
|
746
|
+
} else {
|
747
|
+
opt_ctx->allocated_graph_copy = graph;
|
748
|
+
}
|
749
|
+
|
750
|
+
lm_ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
751
|
+
opt_ctx->allocated_graph = graph;
|
752
|
+
|
753
|
+
opt_ctx->eval_ready = true;
|
754
|
+
}
|
755
|
+
|
756
|
+
void lm_ggml_opt_eval(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result_t result) {
|
757
|
+
LM_GGML_ASSERT(opt_ctx->eval_ready);
|
758
|
+
if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
|
587
759
|
struct lm_ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
|
588
760
|
|
589
761
|
LM_GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
|
@@ -609,9 +781,19 @@ static void lm_ggml_opt_eval_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph
|
|
609
781
|
adamw_par_data[6] = beta2h;
|
610
782
|
}
|
611
783
|
|
612
|
-
lm_ggml_opt_alloc_graph(opt_ctx, graph);
|
613
784
|
lm_ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
614
785
|
opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
|
786
|
+
opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
|
787
|
+
|
788
|
+
if (!opt_ctx->static_graphs) {
|
789
|
+
opt_ctx->gf = nullptr;
|
790
|
+
opt_ctx->gb_grad = nullptr;
|
791
|
+
opt_ctx->gb_opt = nullptr;
|
792
|
+
opt_ctx->allocated_graph = nullptr;
|
793
|
+
opt_ctx->allocated_graph_copy = nullptr;
|
794
|
+
}
|
795
|
+
|
796
|
+
opt_ctx->eval_ready = false;
|
615
797
|
|
616
798
|
if (!result) {
|
617
799
|
return;
|
@@ -635,12 +817,14 @@ static void lm_ggml_opt_eval_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph
|
|
635
817
|
lm_ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, lm_ggml_nbytes(opt_ctx->loss));
|
636
818
|
result->loss.push_back(loss);
|
637
819
|
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
820
|
+
if (opt_ctx->pred) {
|
821
|
+
LM_GGML_ASSERT(opt_ctx->pred->type == LM_GGML_TYPE_I32);
|
822
|
+
std::vector<int32_t> pred(ndata);
|
823
|
+
lm_ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, lm_ggml_nbytes(opt_ctx->pred));
|
824
|
+
result->pred.insert(result->pred.end(), pred.begin(), pred.end());
|
825
|
+
}
|
642
826
|
|
643
|
-
if (!opt_ctx->
|
827
|
+
if (!opt_ctx->ncorrect || result->ncorrect < 0) {
|
644
828
|
result->ncorrect = -1;
|
645
829
|
return;
|
646
830
|
}
|
@@ -652,26 +836,6 @@ static void lm_ggml_opt_eval_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph
|
|
652
836
|
result->ncorrect += ncorrect;
|
653
837
|
}
|
654
838
|
|
655
|
-
void lm_ggml_opt_forward(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result * result) {
|
656
|
-
lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
|
657
|
-
}
|
658
|
-
|
659
|
-
void lm_ggml_opt_forward_backward(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result * result) {
|
660
|
-
if (opt_ctx->opt_period == 1) {
|
661
|
-
lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
|
662
|
-
return;
|
663
|
-
}
|
664
|
-
|
665
|
-
const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
|
666
|
-
if (opt_i_next == 0) {
|
667
|
-
lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
|
668
|
-
lm_ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
|
669
|
-
} else {
|
670
|
-
lm_ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
|
671
|
-
}
|
672
|
-
opt_ctx->opt_i = opt_i_next;
|
673
|
-
}
|
674
|
-
|
675
839
|
// ====== High-Level Functions ======
|
676
840
|
|
677
841
|
void lm_ggml_opt_epoch(
|
@@ -682,6 +846,7 @@ void lm_ggml_opt_epoch(
|
|
682
846
|
int64_t idata_split,
|
683
847
|
lm_ggml_opt_epoch_callback callback_train,
|
684
848
|
lm_ggml_opt_epoch_callback callback_eval) {
|
849
|
+
LM_GGML_ASSERT(lm_ggml_opt_static_graphs(opt_ctx) && "lm_ggml_opt_epoch requires static graphs");
|
685
850
|
struct lm_ggml_tensor * inputs = lm_ggml_opt_inputs(opt_ctx);
|
686
851
|
struct lm_ggml_tensor * labels = lm_ggml_opt_labels(opt_ctx);
|
687
852
|
struct lm_ggml_tensor * data = lm_ggml_opt_dataset_data(dataset);
|
@@ -700,16 +865,18 @@ void lm_ggml_opt_epoch(
|
|
700
865
|
int64_t ibatch = 0;
|
701
866
|
int64_t t_loop_start = lm_ggml_time_us();
|
702
867
|
for (; ibatch < ibatch_split; ++ibatch) {
|
868
|
+
lm_ggml_opt_alloc(opt_ctx, /*backward =*/ true);
|
703
869
|
lm_ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
|
704
|
-
|
870
|
+
lm_ggml_opt_eval(opt_ctx, result_train);
|
705
871
|
if (callback_train) {
|
706
872
|
callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
|
707
873
|
}
|
708
874
|
}
|
709
875
|
t_loop_start = lm_ggml_time_us();
|
710
876
|
for (; ibatch < nbatches; ++ibatch) {
|
877
|
+
lm_ggml_opt_alloc(opt_ctx, /*backward =*/ false);
|
711
878
|
lm_ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
|
712
|
-
|
879
|
+
lm_ggml_opt_eval(opt_ctx, result_eval);
|
713
880
|
if (callback_eval) {
|
714
881
|
callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
|
715
882
|
}
|
@@ -726,13 +893,26 @@ void lm_ggml_opt_epoch_callback_progress_bar(
|
|
726
893
|
int64_t t_start_us) {
|
727
894
|
fprintf(stderr, "%s[", train ? "train: " : "val: ");
|
728
895
|
|
729
|
-
|
896
|
+
// The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
|
897
|
+
constexpr int64_t bar_length = 8;
|
898
|
+
const int64_t ibatch8 = 8 * ibatch;
|
730
899
|
for (int64_t j = 0; j < bar_length; ++j) {
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
900
|
+
if (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
|
901
|
+
fprintf(stderr, "\u2588"); // full block
|
902
|
+
} else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
|
903
|
+
fprintf(stderr, "\u2589"); // 7/8 filled
|
904
|
+
} else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
|
905
|
+
fprintf(stderr, "\u258A"); // 6/8 filled
|
906
|
+
} else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
|
907
|
+
fprintf(stderr, "\u258B"); // 5/8 filled
|
908
|
+
} else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
|
909
|
+
fprintf(stderr, "\u258C"); // 4/8 filled
|
910
|
+
} else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
|
911
|
+
fprintf(stderr, "\u258D"); // 3/8 filled
|
912
|
+
} else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
|
913
|
+
fprintf(stderr, "\u258E"); // 2/8 filled
|
914
|
+
} else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
|
915
|
+
fprintf(stderr, "\u258F"); // 1/8 filled
|
736
916
|
} else {
|
737
917
|
fprintf(stderr, " ");
|
738
918
|
}
|
@@ -764,8 +944,8 @@ void lm_ggml_opt_epoch_callback_progress_bar(
|
|
764
944
|
const int64_t t_eta_m = t_eta_s / 60;
|
765
945
|
t_eta_s -= t_eta_m * 60;
|
766
946
|
|
767
|
-
fprintf(stderr, "
|
768
|
-
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "
|
947
|
+
fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
|
948
|
+
"t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
|
769
949
|
idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
|
770
950
|
t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
|
771
951
|
if (ibatch == ibatch_max) {
|
@@ -806,7 +986,10 @@ void lm_ggml_opt_fit(
|
|
806
986
|
|
807
987
|
int64_t epoch = 1;
|
808
988
|
|
809
|
-
lm_ggml_opt_params params = lm_ggml_opt_default_params(backend_sched,
|
989
|
+
lm_ggml_opt_params params = lm_ggml_opt_default_params(backend_sched, loss_type);
|
990
|
+
params.ctx_compute = ctx_compute;
|
991
|
+
params.inputs = inputs;
|
992
|
+
params.outputs = outputs;
|
810
993
|
params.opt_period = opt_period;
|
811
994
|
params.get_opt_pars = get_opt_pars;
|
812
995
|
params.get_opt_pars_ud = &epoch;
|