cui-llama.rn 1.6.1 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +6 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +38 -5
- package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
- package/android/src/main/jni.cpp +153 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/chat.cpp +128 -106
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +41 -76
- package/cpp/common.h +23 -19
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/cpp/ggml-cpu/ggml-cpu.c +5 -13
- package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
- package/cpp/ggml-cpu/ops.cpp +107 -13
- package/cpp/ggml-cpu/vec.cpp +0 -6
- package/cpp/ggml-cpu/vec.h +16 -0
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +321 -132
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +93 -38
- package/cpp/ggml.h +21 -7
- package/cpp/gguf.cpp +33 -33
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +3 -0
- package/cpp/llama-batch.cpp +3 -1
- package/cpp/llama-chat.cpp +8 -6
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +349 -135
- package/cpp/llama-context.h +30 -3
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +150 -234
- package/cpp/llama-graph.h +52 -7
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +34 -5
- package/cpp/llama-kv-cache.cpp +662 -321
- package/cpp/llama-kv-cache.h +203 -93
- package/cpp/llama-memory.h +3 -2
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +536 -132
- package/cpp/llama-model.h +7 -1
- package/cpp/llama-sampling.cpp +18 -6
- package/cpp/llama-vocab.cpp +46 -8
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +72 -131
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +611 -47
- package/cpp/rn-llama.h +33 -3
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +7 -0
- package/ios/RNLlama.mm +77 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +105 -10
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +50 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +67 -4
- package/src/index.ts +212 -38
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
package/cpp/llama-context.h
CHANGED
@@ -7,6 +7,7 @@
 #include "llama-adapter.h"
 
 #include "ggml-cpp.h"
+#include "ggml-opt.h"
 
 #include <map>
 #include <vector>
@@ -133,6 +134,32 @@ struct llama_context {
     llama_perf_context_data perf_get_data() const;
     void perf_reset();
 
+    //
+    // training
+    //
+
+    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
+
+    void opt_epoch(
+            lm_ggml_opt_dataset_t      dataset,
+            lm_ggml_opt_result_t       result_train,
+            lm_ggml_opt_result_t       result_eval,
+            int64_t                    idata_split,
+            lm_ggml_opt_epoch_callback callback_train,
+            lm_ggml_opt_epoch_callback callback_eval);
+
+    void opt_epoch_iter(
+            lm_ggml_opt_dataset_t            dataset,
+            lm_ggml_opt_result_t             result,
+            const std::vector<llama_token> & tokens,
+            const std::vector<llama_token> & labels_sparse,
+            llama_batch                    & batch,
+            lm_ggml_opt_epoch_callback       callback,
+            bool                             train,
+            int64_t                          idata_in_loop,
+            int64_t                          ndata_in_loop,
+            int64_t                          t_loop_start);
+
 private:
     //
     // output
@@ -187,9 +214,6 @@ private:
 
     std::unique_ptr<llama_memory_i> memory;
 
-    // TODO: remove
-    bool logits_all = false;
-
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t logits_size = 0; // capacity (of floats) for logits
     float * logits = nullptr;
@@ -215,6 +239,9 @@ private:
 
     lm_ggml_context_ptr ctx_compute;
 
+    // training
+    lm_ggml_opt_context_t opt_ctx = nullptr;
+
     lm_ggml_threadpool_t threadpool = nullptr;
     lm_ggml_threadpool_t threadpool_batch = nullptr;
 
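Note: the hunks above wire ggml's optimizer into llama_context and add training (fine-tuning) entry points. Below is a minimal sketch of how these new members could be driven for a single epoch; it is illustrative only and not part of this diff. The lm_ggml_opt_result_init/lm_ggml_opt_result_free calls are assumed to be the lm_-prefixed equivalents of ggml-opt.h's helpers, the zero-initialized llama_opt_params and the 90/10 train/eval split are placeholder choices, and the real parameter fields live in the updated llama.h (+72 -131 above).

    // Sketch only (not from this package): one training epoch via the new
    // llama_context members declared in the hunk above.
    static void run_one_epoch(llama_context & ctx,
                              llama_model   * model,
                              lm_ggml_opt_dataset_t dataset,
                              int64_t ndata) {
        llama_opt_params params = {};                  // real fields are defined in llama.h
        ctx.opt_init(model, params);

        // assumed lm_-prefixed equivalents of ggml_opt_result_init/free from ggml-opt.h
        lm_ggml_opt_result_t result_train = lm_ggml_opt_result_init();
        lm_ggml_opt_result_t result_eval  = lm_ggml_opt_result_init();

        const int64_t idata_split = ndata*9/10;        // placeholder: first 90% train, rest eval

        ctx.opt_epoch(dataset, result_train, result_eval, idata_split,
                      /*callback_train=*/nullptr, /*callback_eval=*/nullptr);

        lm_ggml_opt_result_free(result_train);
        lm_ggml_opt_result_free(result_eval);
    }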
package/cpp/llama-cparams.h
CHANGED
package/cpp/llama-graph.cpp
CHANGED
@@ -9,33 +9,6 @@
 #include <cmath>
 #include <cstring>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-
-    return relative_bucket;
-}
-
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-
-
-        LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(pos_bucket->buffer));
-        LM_GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
-
-        int32_t * data = (int32_t *) pos_bucket->data;
-
-        const int64_t n_kv = kv_self->n;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
-                }
-            }
-        }
+        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -403,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask
-
-
-
-    const int64_t n_seqs = ubatch->n_seqs;
-
-    float * data = nullptr;
-    float * data_swa = nullptr;
-
-    if (self_kq_mask) {
-        LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(self_kq_mask->buffer));
-        data = (float *) self_kq_mask->data;
-    }
-
-    if (self_kq_mask_swa) {
-        LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-        data_swa = (float *) self_kq_mask_swa->data;
-    }
-
-    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-    //   Causal mask:
-    //      xxx-------
-    //      xxxx------
-    //      xxxxx-----
-    //   Non-causal mask:
-    //      xxxxx-----
-    //      xxxxx-----
-    //      xxxxx-----
-    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    for (int h = 0; h < 1; ++h) {
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-            for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-                for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    // mask the token if:
-                    if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
-                        || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
-                    ) {
-                        f = -INFINITY;
-                    } else {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(kv_self->cells[i].pos - pos);
-                        } else {
-                            f = 0.0f;
-                        }
-                    }
-
-                    if (data) {
-                        data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                    }
-
-                    // may need to cut off old tokens for sliding window
-                    // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                    if (data_swa) {
-                        if (hparams.n_attn_chunk) {
-                            llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                            if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                f = -INFINITY;
-                            }
-                        } else {
-                            if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                f = -INFINITY;
-                            }
-                        }
-                        data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                    }
-                }
-            }
-        }
+    if (self_kq_mask) {
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+}
 
-
-
-
-
-                data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-            }
-        }
-    }
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
 
-
-
-        for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
-            for (int j = 0; j < n_kv; ++j) {
-                data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-            }
-        }
-    }
-}
+    if (self_kq_mask_swa) {
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -545,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer (hparams.n_layer),
     n_rot (hparams.n_rot),
     n_ctx (cparams.n_ctx),
-    n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max),
     n_head (hparams.n_head()),
     n_head_kv (hparams.n_head_kv()),
     n_embd_head_k (hparams.n_embd_head_k),
@@ -782,7 +658,7 @@ lm_ggml_tensor * llm_graph_context::build_ffn(
             } break;
     }
 
-    if (type_gate == LLM_FFN_PAR) {
+    if (gate && type_gate == LLM_FFN_PAR) {
         cur = lm_ggml_mul(ctx0, cur, tmp);
         cb(cur, "ffn_gate_par", il);
     }
@@ -971,6 +847,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_embd(lm_ggml_tensor * tok_embd) co
     inp->tokens = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, ubatch.n_tokens);
     //cb(inp->tokens, "inp_tokens", -1);
     lm_ggml_set_input(inp->tokens);
+    res->t_tokens = inp->tokens;
 
     cur = lm_ggml_get_rows(ctx0, tok_embd, inp->tokens);
 
@@ -1152,7 +1029,7 @@ lm_ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->
+    const auto n_kv = kv_self->get_n();
 
     auto & cur = inp->pos_bucket;
 
@@ -1187,16 +1064,12 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
         lm_ggml_tensor * kq_b,
         lm_ggml_tensor * kq_mask,
         lm_ggml_tensor * v_mla,
-        bool v_trans,
         float kq_scale) const {
-
-    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    const bool v_trans = v->nb[1] > v->nb[2];
 
-
-
-
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+    q = lm_ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = lm_ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = lm_ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head = q->ne[2];
@@ -1227,8 +1100,19 @@ lm_ggml_tensor * llm_graph_context::build_attn_mha(
         lm_ggml_flash_attn_ext_set_prec(cur, LM_GGML_PREC_F32);
 
         if (v_mla) {
+#if 0
+            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
+            // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
             cur = lm_ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
             cur = lm_ggml_mul_mat(ctx0, v_mla, cur);
+#else
+            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
+            // The permutations are noops and only change how the tensor data is interpreted.
+            cur = lm_ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = lm_ggml_mul_mat(ctx0, v_mla, cur);
+            cur = lm_ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = lm_ggml_cont(ctx0, cur); // Needed because lm_ggml_reshape_2d expects contiguous inputs.
+#endif
         }
 
         cur = lm_ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
@@ -1324,17 +1208,11 @@ lm_ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    lm_ggml_tensor * q =
-
-
-    lm_ggml_tensor * k = lm_ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = lm_ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = k_cur;
+    lm_ggml_tensor * v = v_cur;
 
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1357,22 +1235,16 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
-
-
-    inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
-    lm_ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
+    {
+        LM_GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-
-        LM_GGML_ASSERT(hparams.n_swa > 0);
+        const auto n_kv = kv_self->get_n();
 
-        inp->
-        //cb(inp->
-        lm_ggml_set_input(inp->
+        inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        lm_ggml_set_input(inp->self_kq_mask);
 
-        inp->
+        inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
@@ -1397,85 +1269,108 @@ lm_ggml_tensor * llm_graph_context::build_attn(
     lm_ggml_build_forward_expand(gf, v_cur);
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
 
-
-
+    // store to KV cache
+    {
+        lm_ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        lm_ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
 
-    const auto
+    const auto & kq_mask = inp->get_kq_mask();
 
-
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = kv_self->get_k(ctx0, il);
+    lm_ggml_tensor * v = kv_self->get_v(ctx0, il);
 
-
-
-    const auto kv_head = kv_self->head;
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
 
-
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
 
-
-
+    if (wo_b) {
+        cur = lm_ggml_add(ctx0, cur, wo_b);
+    }
 
-
-
+    return cur;
+}
 
-
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
 
-
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
 
-
-
-    } else {
-        // note: the V cache is transposed when not using flash attention
-        v_cache_view = lm_ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                ( n_ctx)*lm_ggml_element_size(kv_self->v_l[il]),
-                (kv_head)*lm_ggml_element_size(kv_self->v_l[il]));
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
 
-
-
-
+        inp->self_kq_mask = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        lm_ggml_set_input(inp->self_kq_mask);
 
-
+        inp->self_kq_mask_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask, LM_GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    {
+        LM_GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
+
+        inp->self_kq_mask_swa = lm_ggml_new_tensor_2d(ctx0, LM_GGML_TYPE_F32, n_kv, LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        lm_ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? lm_ggml_cast(ctx0, inp->self_kq_mask_swa, LM_GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+lm_ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        lm_ggml_cgraph * gf,
+        lm_ggml_tensor * wo,
+        lm_ggml_tensor * wo_b,
+        lm_ggml_tensor * q_cur,
+        lm_ggml_tensor * k_cur,
+        lm_ggml_tensor * v_cur,
+        lm_ggml_tensor * kq_b,
+        lm_ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    lm_ggml_build_forward_expand(gf, q_cur);
+    lm_ggml_build_forward_expand(gf, k_cur);
+    lm_ggml_build_forward_expand(gf, v_cur);
+
     const bool is_swa = hparams.is_swa(il);
 
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+
+    // store to KV cache
+    {
+        lm_ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        lm_ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+    }
+
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
-
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = kv->get_k(ctx0, il);
+    lm_ggml_tensor * v = kv->get_v(ctx0, il);
 
-
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    lm_ggml_tensor * q = lm_ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    lm_ggml_tensor * k =
-        lm_ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = !v_trans ?
-        lm_ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        lm_ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                lm_ggml_element_size(kv_self->v_l[il])*n_ctx,
-                lm_ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
         cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            lm_ggml_mul_mat_set_prec(cur, LM_GGML_PREC_F32);
+        }
     }
 
     if (wo_b) {
@@ -1522,17 +1417,11 @@ lm_ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
-    lm_ggml_tensor * q =
-
-
-    lm_ggml_tensor * k = lm_ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    lm_ggml_tensor * v = lm_ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    lm_ggml_tensor * q = q_cur;
+    lm_ggml_tensor * k = k_cur;
+    lm_ggml_tensor * v = v_cur;
 
+    lm_ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1700,3 +1589,30 @@ void llm_graph_context::build_pooling(
 
     lm_ggml_build_forward_expand(gf, cur);
 }
+
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}