cui-llama.rn 1.6.1 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +6 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +38 -5
- package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
- package/android/src/main/jni.cpp +153 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/chat.cpp +128 -106
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +41 -76
- package/cpp/common.h +23 -19
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/cpp/ggml-cpu/ggml-cpu.c +5 -13
- package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
- package/cpp/ggml-cpu/ops.cpp +107 -13
- package/cpp/ggml-cpu/vec.cpp +0 -6
- package/cpp/ggml-cpu/vec.h +16 -0
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +321 -132
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +93 -38
- package/cpp/ggml.h +21 -7
- package/cpp/gguf.cpp +33 -33
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +3 -0
- package/cpp/llama-batch.cpp +3 -1
- package/cpp/llama-chat.cpp +8 -6
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +349 -135
- package/cpp/llama-context.h +30 -3
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +150 -234
- package/cpp/llama-graph.h +52 -7
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +34 -5
- package/cpp/llama-kv-cache.cpp +662 -321
- package/cpp/llama-kv-cache.h +203 -93
- package/cpp/llama-memory.h +3 -2
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +536 -132
- package/cpp/llama-model.h +7 -1
- package/cpp/llama-sampling.cpp +18 -6
- package/cpp/llama-vocab.cpp +46 -8
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +72 -131
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +611 -47
- package/cpp/rn-llama.h +33 -3
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +7 -0
- package/ios/RNLlama.mm +77 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +105 -10
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +50 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +67 -4
- package/src/index.ts +212 -38
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
package/cpp/ggml-opt.h
CHANGED
@@ -37,13 +37,16 @@ extern "C" {
|
|
37
37
|
// ====== Dataset ======
|
38
38
|
|
39
39
|
LM_GGML_API lm_ggml_opt_dataset_t lm_ggml_opt_dataset_init(
|
40
|
-
|
41
|
-
|
42
|
-
int64_t
|
43
|
-
int64_t
|
40
|
+
enum lm_ggml_type type_data, // the type for the internal data tensor
|
41
|
+
enum lm_ggml_type type_label, // the type for the internal labels tensor
|
42
|
+
int64_t ne_datapoint, // number of elements per datapoint
|
43
|
+
int64_t ne_label, // number of elements per label
|
44
|
+
int64_t ndata, // total number of datapoints/labels
|
45
|
+
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
|
44
46
|
LM_GGML_API void lm_ggml_opt_dataset_free(lm_ggml_opt_dataset_t dataset);
|
45
47
|
|
46
48
|
// get underlying tensors that store the data
|
49
|
+
LM_GGML_API int64_t lm_ggml_opt_dataset_ndata (lm_ggml_opt_dataset_t dataset);
|
47
50
|
LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_dataset_data (lm_ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
|
48
51
|
LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_dataset_labels(lm_ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
|
49
52
|
|
@@ -56,13 +59,19 @@ extern "C" {
|
|
56
59
|
struct lm_ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
|
57
60
|
struct lm_ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
|
58
61
|
int64_t ibatch);
|
62
|
+
LM_GGML_API void lm_ggml_opt_dataset_get_batch_host(
|
63
|
+
lm_ggml_opt_dataset_t dataset,
|
64
|
+
void * data_batch,
|
65
|
+
size_t nb_data_batch,
|
66
|
+
void * labels_batch,
|
67
|
+
int64_t ibatch);
|
59
68
|
|
60
69
|
// ====== Model / Context ======
|
61
70
|
|
62
71
|
enum lm_ggml_opt_build_type {
|
63
|
-
LM_GGML_OPT_BUILD_TYPE_FORWARD,
|
64
|
-
LM_GGML_OPT_BUILD_TYPE_GRAD,
|
65
|
-
LM_GGML_OPT_BUILD_TYPE_OPT,
|
72
|
+
LM_GGML_OPT_BUILD_TYPE_FORWARD = 10,
|
73
|
+
LM_GGML_OPT_BUILD_TYPE_GRAD = 20,
|
74
|
+
LM_GGML_OPT_BUILD_TYPE_OPT = 30,
|
66
75
|
};
|
67
76
|
|
68
77
|
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
|
@@ -81,20 +90,22 @@ extern "C" {
|
|
81
90
|
// userdata can be used to pass arbitrary data
|
82
91
|
typedef struct lm_ggml_opt_optimizer_params (*lm_ggml_opt_get_optimizer_params)(void * userdata);
|
83
92
|
|
84
|
-
// returns the default optimizer params (constant)
|
93
|
+
// returns the default optimizer params (constant, hard-coded values)
|
85
94
|
// userdata is not used
|
86
95
|
LM_GGML_API struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_default_optimizer_params(void * userdata);
|
87
96
|
|
97
|
+
// casts userdata to lm_ggml_opt_optimizer_params and returns it
|
98
|
+
LM_GGML_API struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_constant_optimizer_params(void * userdata);
|
99
|
+
|
88
100
|
// parameters for initializing a new optimization context
|
89
101
|
struct lm_ggml_opt_params {
|
90
102
|
lm_ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
|
91
103
|
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
struct lm_ggml_tensor
|
97
|
-
struct lm_ggml_tensor * outputs;
|
104
|
+
// by default the forward graph needs to be reconstructed for each eval
|
105
|
+
// if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
|
106
|
+
struct lm_ggml_context * ctx_compute;
|
107
|
+
struct lm_ggml_tensor * inputs;
|
108
|
+
struct lm_ggml_tensor * outputs;
|
98
109
|
|
99
110
|
enum lm_ggml_opt_loss_type loss_type;
|
100
111
|
enum lm_ggml_opt_build_type build_type;
|
@@ -107,12 +118,9 @@ extern "C" {
|
|
107
118
|
|
108
119
|
// get parameters for an optimization context with defaults set where possible
|
109
120
|
// parameters for which no sensible defaults exist are supplied as arguments to this function
|
110
|
-
LM_GGML_API lm_ggml_opt_params lm_ggml_opt_default_params(
|
111
|
-
lm_ggml_backend_sched_t
|
112
|
-
|
113
|
-
struct lm_ggml_tensor * inputs,
|
114
|
-
struct lm_ggml_tensor * outputs,
|
115
|
-
enum lm_ggml_opt_loss_type loss_type);
|
121
|
+
LM_GGML_API struct lm_ggml_opt_params lm_ggml_opt_default_params(
|
122
|
+
lm_ggml_backend_sched_t backend_sched,
|
123
|
+
enum lm_ggml_opt_loss_type loss_type);
|
116
124
|
|
117
125
|
LM_GGML_API lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params);
|
118
126
|
LM_GGML_API void lm_ggml_opt_free(lm_ggml_opt_context_t opt_ctx);
|
@@ -120,7 +128,10 @@ extern "C" {
|
|
120
128
|
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
121
129
|
LM_GGML_API void lm_ggml_opt_reset(lm_ggml_opt_context_t opt_ctx, bool optimizer);
|
122
130
|
|
131
|
+
LM_GGML_API bool lm_ggml_opt_static_graphs(lm_ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
|
132
|
+
|
123
133
|
// get underlying tensors that store data
|
134
|
+
// if not using static graphs these pointers become invalid with the next call to lm_ggml_opt_alloc
|
124
135
|
LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_inputs( lm_ggml_opt_context_t opt_ctx); // forward graph input tensor
|
125
136
|
LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_outputs( lm_ggml_opt_context_t opt_ctx); // forward graph output tensor
|
126
137
|
LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_labels( lm_ggml_opt_context_t opt_ctx); // labels to compare outputs against
|
@@ -128,11 +139,12 @@ extern "C" {
|
|
128
139
|
LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_pred( lm_ggml_opt_context_t opt_ctx); // predictions made by outputs
|
129
140
|
LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_ncorrect(lm_ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
|
130
141
|
|
142
|
+
// get the gradient accumulator for a node from the forward graph
|
131
143
|
LM_GGML_API struct lm_ggml_tensor * lm_ggml_opt_grad_acc(lm_ggml_opt_context_t opt_ctx, struct lm_ggml_tensor * node);
|
132
144
|
|
133
145
|
// ====== Optimization Result ======
|
134
146
|
|
135
|
-
LM_GGML_API lm_ggml_opt_result_t lm_ggml_opt_result_init();
|
147
|
+
LM_GGML_API lm_ggml_opt_result_t lm_ggml_opt_result_init(void);
|
136
148
|
LM_GGML_API void lm_ggml_opt_result_free(lm_ggml_opt_result_t result);
|
137
149
|
LM_GGML_API void lm_ggml_opt_result_reset(lm_ggml_opt_result_t result);
|
138
150
|
|
@@ -144,11 +156,20 @@ extern "C" {
|
|
144
156
|
|
145
157
|
// ====== Computation ======
|
146
158
|
|
147
|
-
//
|
148
|
-
LM_GGML_API void
|
159
|
+
// if not using static graphs, this function must be called prior to lm_ggml_opt_alloc
|
160
|
+
LM_GGML_API void lm_ggml_opt_prepare_alloc(
|
161
|
+
lm_ggml_opt_context_t opt_ctx,
|
162
|
+
struct lm_ggml_context * ctx_compute,
|
163
|
+
struct lm_ggml_cgraph * gf,
|
164
|
+
struct lm_ggml_tensor * inputs,
|
165
|
+
struct lm_ggml_tensor * outputs);
|
166
|
+
|
167
|
+
// allocate the next graph for evaluation, either forward or forward + backward
|
168
|
+
// must be called exactly once prior to calling lm_ggml_opt_eval
|
169
|
+
LM_GGML_API void lm_ggml_opt_alloc(lm_ggml_opt_context_t opt_ctx, bool backward);
|
149
170
|
|
150
|
-
// do forward pass, increment result if not NULL, do backward pass
|
151
|
-
LM_GGML_API void
|
171
|
+
// do forward pass, increment result if not NULL, do backward pass if allocated
|
172
|
+
LM_GGML_API void lm_ggml_opt_eval(lm_ggml_opt_context_t opt_ctx, lm_ggml_opt_result_t result);
|
152
173
|
|
153
174
|
// ############################################################################
|
154
175
|
// ## The high-level functions start here. They do not depend on any private ##
|
@@ -200,9 +221,9 @@ extern "C" {
|
|
200
221
|
// fit model defined by inputs and outputs to dataset
|
201
222
|
LM_GGML_API void lm_ggml_opt_fit(
|
202
223
|
lm_ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
|
203
|
-
lm_ggml_context
|
204
|
-
lm_ggml_tensor
|
205
|
-
lm_ggml_tensor
|
224
|
+
struct lm_ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
|
225
|
+
struct lm_ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
|
226
|
+
struct lm_ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
|
206
227
|
lm_ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
|
207
228
|
enum lm_ggml_opt_loss_type loss_type, // loss to minimize
|
208
229
|
lm_ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
|
package/cpp/ggml-quants.c
CHANGED
@@ -19,12 +19,6 @@
|
|
19
19
|
#define GROUP_MAX_EPS_IQ1_M 1e-7f
|
20
20
|
#define GROUP_MAX_EPS_IQ1_S 1e-12f
|
21
21
|
|
22
|
-
#if defined(_MSC_VER)
|
23
|
-
// disable "possible loss of data" to avoid warnings for hundreds of casts
|
24
|
-
// we should just be careful :)
|
25
|
-
#pragma warning(disable: 4244 4267)
|
26
|
-
#endif
|
27
|
-
|
28
22
|
#define UNUSED LM_GGML_UNUSED
|
29
23
|
|
30
24
|
// reference implementation for deterministic creation of model files
|
package/cpp/ggml.c
CHANGED
@@ -64,12 +64,17 @@
|
|
64
64
|
// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
|
65
65
|
float lm_ggml_table_f32_f16[1 << 16];
|
66
66
|
|
67
|
-
#if
|
68
|
-
(
|
67
|
+
#if defined(__linux__) || \
|
68
|
+
defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
|
69
|
+
(defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
|
70
|
+
|
69
71
|
#include <unistd.h>
|
70
72
|
#include <sys/types.h>
|
71
73
|
#include <sys/stat.h>
|
72
74
|
#include <sys/wait.h>
|
75
|
+
#if defined(__linux__)
|
76
|
+
#include <sys/prctl.h>
|
77
|
+
#endif
|
73
78
|
|
74
79
|
#if defined(__ANDROID__)
|
75
80
|
#include <unwind.h>
|
@@ -133,10 +138,36 @@ static void lm_ggml_print_backtrace(void) {
|
|
133
138
|
if (LM_GGML_NO_BACKTRACE) {
|
134
139
|
return;
|
135
140
|
}
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
141
|
+
#if defined(__linux__)
|
142
|
+
FILE * f = fopen("/proc/self/status", "r");
|
143
|
+
size_t size = 0;
|
144
|
+
char * line = NULL;
|
145
|
+
ssize_t length = 0;
|
146
|
+
while ((length = getline(&line, &size, f)) > 0) {
|
147
|
+
if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) &&
|
148
|
+
(length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) {
|
149
|
+
// Already being debugged, and the breakpoint is the later abort()
|
150
|
+
free(line);
|
151
|
+
fclose(f);
|
152
|
+
return;
|
153
|
+
}
|
154
|
+
}
|
155
|
+
free(line);
|
156
|
+
fclose(f);
|
157
|
+
int lock[2] = { -1, -1 };
|
158
|
+
(void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER
|
159
|
+
#endif
|
160
|
+
const int parent_pid = getpid();
|
161
|
+
const int child_pid = fork();
|
162
|
+
if (child_pid < 0) { // error
|
163
|
+
return;
|
164
|
+
} else if (child_pid == 0) { // child
|
165
|
+
char attach[32];
|
166
|
+
snprintf(attach, sizeof(attach), "attach %d", parent_pid);
|
167
|
+
#if defined(__linux__)
|
168
|
+
close(lock[1]);
|
169
|
+
(void) !read(lock[0], lock, 1);
|
170
|
+
#endif
|
140
171
|
// try gdb
|
141
172
|
execlp("gdb", "gdb", "--batch",
|
142
173
|
"-ex", "set style enabled on",
|
@@ -149,18 +180,18 @@ static void lm_ggml_print_backtrace(void) {
|
|
149
180
|
execlp("lldb", "lldb", "--batch",
|
150
181
|
"-o", "bt",
|
151
182
|
"-o", "quit",
|
152
|
-
"-p", attach,
|
183
|
+
"-p", &attach[sizeof("attach ") - 1],
|
153
184
|
(char *) NULL);
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
185
|
+
// gdb failed, fallback to backtrace_symbols
|
186
|
+
lm_ggml_print_backtrace_symbols();
|
187
|
+
_Exit(0);
|
188
|
+
} else { // parent
|
189
|
+
#if defined(__linux__)
|
190
|
+
prctl(PR_SET_PTRACER, child_pid);
|
191
|
+
close(lock[1]);
|
192
|
+
close(lock[0]);
|
193
|
+
#endif
|
194
|
+
waitpid(child_pid, NULL, 0);
|
164
195
|
}
|
165
196
|
}
|
166
197
|
#else
|
@@ -1081,9 +1112,10 @@ static const char * LM_GGML_UNARY_OP_NAME[LM_GGML_UNARY_OP_COUNT] = {
|
|
1081
1112
|
"HARDSWISH",
|
1082
1113
|
"HARDSIGMOID",
|
1083
1114
|
"EXP",
|
1115
|
+
"GELU_ERF",
|
1084
1116
|
};
|
1085
1117
|
|
1086
|
-
static_assert(LM_GGML_UNARY_OP_COUNT ==
|
1118
|
+
static_assert(LM_GGML_UNARY_OP_COUNT == 15, "LM_GGML_UNARY_OP_COUNT != 15");
|
1087
1119
|
|
1088
1120
|
|
1089
1121
|
static_assert(sizeof(struct lm_ggml_object)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_object size must be a multiple of LM_GGML_MEM_ALIGN");
|
@@ -1312,6 +1344,10 @@ bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor) {
|
|
1312
1344
|
return lm_ggml_is_contiguous_n(tensor, 2);
|
1313
1345
|
}
|
1314
1346
|
|
1347
|
+
bool lm_ggml_is_contiguously_allocated(const struct lm_ggml_tensor * tensor) {
|
1348
|
+
return lm_ggml_nbytes(tensor) == lm_ggml_nelements(tensor) * lm_ggml_type_size(tensor->type)/lm_ggml_blck_size(tensor->type);
|
1349
|
+
}
|
1350
|
+
|
1315
1351
|
bool lm_ggml_is_permuted(const struct lm_ggml_tensor * tensor) {
|
1316
1352
|
static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function");
|
1317
1353
|
|
@@ -2479,6 +2515,20 @@ struct lm_ggml_tensor * lm_ggml_gelu_inplace(
|
|
2479
2515
|
return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_GELU);
|
2480
2516
|
}
|
2481
2517
|
|
2518
|
+
// lm_ggml_gelu_erf
|
2519
|
+
|
2520
|
+
struct lm_ggml_tensor * lm_ggml_gelu_erf(
|
2521
|
+
struct lm_ggml_context * ctx,
|
2522
|
+
struct lm_ggml_tensor * a) {
|
2523
|
+
return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_GELU_ERF);
|
2524
|
+
}
|
2525
|
+
|
2526
|
+
struct lm_ggml_tensor * lm_ggml_gelu_erf_inplace(
|
2527
|
+
struct lm_ggml_context * ctx,
|
2528
|
+
struct lm_ggml_tensor * a) {
|
2529
|
+
return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_GELU_ERF);
|
2530
|
+
}
|
2531
|
+
|
2482
2532
|
// lm_ggml_gelu_quick
|
2483
2533
|
|
2484
2534
|
struct lm_ggml_tensor * lm_ggml_gelu_quick(
|
@@ -2741,11 +2791,11 @@ void lm_ggml_mul_mat_set_prec(
|
|
2741
2791
|
c = lm_ggml_mul_mat_id(ctx, as, b, ids);
|
2742
2792
|
|
2743
2793
|
as -> [cols, rows, n_expert]
|
2744
|
-
ids -> [n_experts_used, n_tokens] (i32)
|
2745
2794
|
b -> [cols, n_expert_used, n_tokens]
|
2795
|
+
ids -> [n_expert_used, n_tokens] (i32)
|
2746
2796
|
c -> [rows, n_expert_used, n_tokens]
|
2747
2797
|
|
2748
|
-
in b,
|
2798
|
+
in b, n_expert_used can be broadcasted to match the n_expert_used of ids
|
2749
2799
|
|
2750
2800
|
c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
|
2751
2801
|
*/
|
@@ -5508,7 +5558,7 @@ static void lm_ggml_compute_backward(
|
|
5508
5558
|
// tensor = src0 * 1 + src1 * 0
|
5509
5559
|
if (src0_needs_grads) {
|
5510
5560
|
// dsrc0 = dtensor * 1
|
5511
|
-
lm_ggml_add_or_set(ctx, cgraph, isrc0, grad);
|
5561
|
+
lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_reshape(ctx, grad, src0));
|
5512
5562
|
}
|
5513
5563
|
if (src1_needs_grads) {
|
5514
5564
|
// dsrc1 = dtensor * 0 -> noop
|
@@ -5789,10 +5839,9 @@ void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml
|
|
5789
5839
|
}
|
5790
5840
|
|
5791
5841
|
void lm_ggml_build_backward_expand(
|
5792
|
-
struct lm_ggml_context *
|
5793
|
-
struct
|
5794
|
-
struct
|
5795
|
-
bool accumulate) {
|
5842
|
+
struct lm_ggml_context * ctx,
|
5843
|
+
struct lm_ggml_cgraph * cgraph,
|
5844
|
+
struct lm_ggml_tensor ** grad_accs) {
|
5796
5845
|
LM_GGML_ASSERT(cgraph->n_nodes > 0);
|
5797
5846
|
LM_GGML_ASSERT(cgraph->grads);
|
5798
5847
|
LM_GGML_ASSERT(cgraph->grad_accs);
|
@@ -5865,21 +5914,24 @@ void lm_ggml_build_backward_expand(
|
|
5865
5914
|
LM_GGML_ASSERT(!node->view_src || node->op == LM_GGML_OP_CPY || node->op == LM_GGML_OP_VIEW ||
|
5866
5915
|
node->op == LM_GGML_OP_RESHAPE || node->op == LM_GGML_OP_PERMUTE || node->op == LM_GGML_OP_TRANSPOSE);
|
5867
5916
|
|
5868
|
-
const size_t
|
5869
|
-
LM_GGML_ASSERT(
|
5870
|
-
LM_GGML_ASSERT(lm_ggml_bitset_get(cgraph->visited_hash_set.used,
|
5871
|
-
if (
|
5872
|
-
cgraph->grad_accs[
|
5873
|
-
cgraph->grads[
|
5874
|
-
|
5917
|
+
const size_t ihash = lm_ggml_hash_find(&cgraph->visited_hash_set, node);
|
5918
|
+
LM_GGML_ASSERT(ihash != LM_GGML_HASHSET_FULL);
|
5919
|
+
LM_GGML_ASSERT(lm_ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
|
5920
|
+
if (grad_accs && grad_accs[i]) {
|
5921
|
+
cgraph->grad_accs[ihash] = grad_accs[i];
|
5922
|
+
cgraph->grads[ihash] = cgraph->grad_accs[ihash];
|
5923
|
+
} else if (node->flags & LM_GGML_TENSOR_FLAG_LOSS) {
|
5924
|
+
// loss tensors always need a gradient accumulator
|
5925
|
+
cgraph->grad_accs[ihash] = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, node->ne);
|
5926
|
+
cgraph->grads[ihash] = cgraph->grad_accs[ihash];
|
5875
5927
|
}
|
5876
|
-
grads_needed[
|
5928
|
+
grads_needed[ihash] = true;
|
5877
5929
|
}
|
5878
5930
|
|
5879
5931
|
for (int i = n_nodes_f - 1; i >= 0; --i) {
|
5880
5932
|
// inplace operations to add gradients are not created by lm_ggml_compute_backward except for gradient accumulation
|
5881
5933
|
// use allocator to automatically make inplace operations
|
5882
|
-
lm_ggml_compute_backward(
|
5934
|
+
lm_ggml_compute_backward(ctx, cgraph, i, grads_needed);
|
5883
5935
|
}
|
5884
5936
|
|
5885
5937
|
free(grads_needed);
|
@@ -6025,8 +6077,8 @@ void lm_ggml_graph_cpy(struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst)
|
|
6025
6077
|
}
|
6026
6078
|
}
|
6027
6079
|
|
6028
|
-
struct lm_ggml_cgraph * lm_ggml_graph_dup(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph) {
|
6029
|
-
struct lm_ggml_cgraph * result = lm_ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads
|
6080
|
+
struct lm_ggml_cgraph * lm_ggml_graph_dup(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, bool force_grads) {
|
6081
|
+
struct lm_ggml_cgraph * result = lm_ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
|
6030
6082
|
lm_ggml_graph_cpy(cgraph, result);
|
6031
6083
|
return result;
|
6032
6084
|
}
|
@@ -6045,6 +6097,9 @@ struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor) {
|
|
6045
6097
|
}
|
6046
6098
|
|
6047
6099
|
void lm_ggml_graph_reset(struct lm_ggml_cgraph * cgraph) {
|
6100
|
+
if (!cgraph) {
|
6101
|
+
return;
|
6102
|
+
}
|
6048
6103
|
LM_GGML_ASSERT(cgraph->grads != NULL);
|
6049
6104
|
|
6050
6105
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
@@ -6354,8 +6409,8 @@ void lm_ggml_set_output(struct lm_ggml_tensor * tensor) {
|
|
6354
6409
|
tensor->flags |= LM_GGML_TENSOR_FLAG_OUTPUT;
|
6355
6410
|
}
|
6356
6411
|
|
6357
|
-
void lm_ggml_set_param(struct
|
6358
|
-
|
6412
|
+
void lm_ggml_set_param(struct lm_ggml_tensor * tensor) {
|
6413
|
+
LM_GGML_ASSERT(tensor->op == LM_GGML_OP_NONE);
|
6359
6414
|
tensor->flags |= LM_GGML_TENSOR_FLAG_PARAM;
|
6360
6415
|
}
|
6361
6416
|
|
package/cpp/ggml.h
CHANGED
@@ -537,6 +537,7 @@ extern "C" {
|
|
537
537
|
LM_GGML_UNARY_OP_HARDSWISH,
|
538
538
|
LM_GGML_UNARY_OP_HARDSIGMOID,
|
539
539
|
LM_GGML_UNARY_OP_EXP,
|
540
|
+
LM_GGML_UNARY_OP_GELU_ERF,
|
540
541
|
|
541
542
|
LM_GGML_UNARY_OP_COUNT,
|
542
543
|
};
|
@@ -674,11 +675,15 @@ extern "C" {
|
|
674
675
|
LM_GGML_API bool lm_ggml_is_3d (const struct lm_ggml_tensor * tensor);
|
675
676
|
LM_GGML_API int lm_ggml_n_dims (const struct lm_ggml_tensor * tensor); // returns 1 for scalars
|
676
677
|
|
678
|
+
// returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)
|
677
679
|
LM_GGML_API bool lm_ggml_is_contiguous (const struct lm_ggml_tensor * tensor);
|
678
680
|
LM_GGML_API bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor); // same as lm_ggml_is_contiguous()
|
679
681
|
LM_GGML_API bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 1
|
680
682
|
LM_GGML_API bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor); // contiguous for dims >= 2
|
681
683
|
|
684
|
+
// returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
|
685
|
+
LM_GGML_API bool lm_ggml_is_contiguously_allocated(const struct lm_ggml_tensor * tensor);
|
686
|
+
|
682
687
|
// true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
|
683
688
|
LM_GGML_API bool lm_ggml_is_contiguous_channels(const struct lm_ggml_tensor * tensor);
|
684
689
|
|
@@ -765,7 +770,7 @@ extern "C" {
|
|
765
770
|
// Tensor flags
|
766
771
|
LM_GGML_API void lm_ggml_set_input(struct lm_ggml_tensor * tensor);
|
767
772
|
LM_GGML_API void lm_ggml_set_output(struct lm_ggml_tensor * tensor);
|
768
|
-
LM_GGML_API void lm_ggml_set_param(struct
|
773
|
+
LM_GGML_API void lm_ggml_set_param(struct lm_ggml_tensor * tensor);
|
769
774
|
LM_GGML_API void lm_ggml_set_loss(struct lm_ggml_tensor * tensor);
|
770
775
|
|
771
776
|
//
|
@@ -935,7 +940,7 @@ extern "C" {
|
|
935
940
|
LM_GGML_API struct lm_ggml_tensor * lm_ggml_repeat_back(
|
936
941
|
struct lm_ggml_context * ctx,
|
937
942
|
struct lm_ggml_tensor * a,
|
938
|
-
struct lm_ggml_tensor * b);
|
943
|
+
struct lm_ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
|
939
944
|
|
940
945
|
// concat a and b along dim
|
941
946
|
// used in stable-diffusion
|
@@ -1021,6 +1026,16 @@ extern "C" {
|
|
1021
1026
|
struct lm_ggml_context * ctx,
|
1022
1027
|
struct lm_ggml_tensor * a);
|
1023
1028
|
|
1029
|
+
// GELU using erf (error function) when possible
|
1030
|
+
// some backends may fallback to approximation based on Abramowitz and Stegun formula
|
1031
|
+
LM_GGML_API struct lm_ggml_tensor * lm_ggml_gelu_erf(
|
1032
|
+
struct lm_ggml_context * ctx,
|
1033
|
+
struct lm_ggml_tensor * a);
|
1034
|
+
|
1035
|
+
LM_GGML_API struct lm_ggml_tensor * lm_ggml_gelu_erf_inplace(
|
1036
|
+
struct lm_ggml_context * ctx,
|
1037
|
+
struct lm_ggml_tensor * a);
|
1038
|
+
|
1024
1039
|
LM_GGML_API struct lm_ggml_tensor * lm_ggml_gelu_quick(
|
1025
1040
|
struct lm_ggml_context * ctx,
|
1026
1041
|
struct lm_ggml_tensor * a);
|
@@ -2046,15 +2061,14 @@ extern "C" {
|
|
2046
2061
|
|
2047
2062
|
LM_GGML_API void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor);
|
2048
2063
|
LM_GGML_API void lm_ggml_build_backward_expand(
|
2049
|
-
struct lm_ggml_context *
|
2050
|
-
struct
|
2051
|
-
struct
|
2052
|
-
bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
|
2064
|
+
struct lm_ggml_context * ctx, // context for gradient computation
|
2065
|
+
struct lm_ggml_cgraph * cgraph,
|
2066
|
+
struct lm_ggml_tensor ** grad_accs);
|
2053
2067
|
|
2054
2068
|
// graph allocation in a context
|
2055
2069
|
LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph (struct lm_ggml_context * ctx); // size = LM_GGML_DEFAULT_GRAPH_SIZE, grads = false
|
2056
2070
|
LM_GGML_API struct lm_ggml_cgraph * lm_ggml_new_graph_custom(struct lm_ggml_context * ctx, size_t size, bool grads);
|
2057
|
-
LM_GGML_API struct lm_ggml_cgraph * lm_ggml_graph_dup (struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph);
|
2071
|
+
LM_GGML_API struct lm_ggml_cgraph * lm_ggml_graph_dup (struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, bool force_grads);
|
2058
2072
|
LM_GGML_API void lm_ggml_graph_cpy (struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst);
|
2059
2073
|
LM_GGML_API void lm_ggml_graph_reset (struct lm_ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
|
2060
2074
|
LM_GGML_API void lm_ggml_graph_clear (struct lm_ggml_cgraph * cgraph);
|