@novastera-oss/llamarn 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +4 -3
- package/cpp/llama.cpp/common/arg.cpp +45 -1
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +18 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
- package/cpp/llama.cpp/include/llama.h +15 -7
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
- package/cpp/llama.cpp/src/llama-arch.h +5 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
- package/cpp/llama.cpp/src/llama-batch.h +24 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
- package/cpp/llama.cpp/src/llama-chat.h +2 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
- package/cpp/llama.cpp/src/llama-graph.h +147 -72
- package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
- package/cpp/llama.cpp/src/llama-hparams.h +10 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
- package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
- package/cpp/llama.cpp/src/llama-model.h +3 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
- package/cpp/llama.cpp/src/llama-vocab.h +2 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/common.h +18 -4
- package/ios/include/llama.h +15 -7
- package/ios/libs/llama.xcframework/Info.plist +15 -15
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
|
@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
|
|
28
28
|
}
|
|
29
29
|
}
|
|
30
30
|
|
|
31
|
+
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
|
32
|
+
bool res = true;
|
|
33
|
+
|
|
34
|
+
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
|
35
|
+
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
|
|
36
|
+
|
|
37
|
+
return res;
|
|
38
|
+
}
|
|
39
|
+
|
|
31
40
|
void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
|
32
41
|
if (ubatch->pos && pos) {
|
|
33
42
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
|
|
|
50
59
|
}
|
|
51
60
|
}
|
|
52
61
|
|
|
62
|
+
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
|
|
63
|
+
bool res = true;
|
|
64
|
+
|
|
65
|
+
res &= pos->ne[0] == params.ubatch.n_tokens;
|
|
66
|
+
|
|
67
|
+
return res;
|
|
68
|
+
}
|
|
69
|
+
|
|
53
70
|
void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
|
|
54
71
|
if (ubatch->pos && attn_scale) {
|
|
55
72
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
|
71
88
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
72
89
|
|
|
73
90
|
GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
|
|
74
|
-
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
|
91
|
+
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
|
|
75
92
|
|
|
76
93
|
int32_t * data = (int32_t *) pos_bucket->data;
|
|
77
94
|
|
|
@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
|
|
|
118
135
|
}
|
|
119
136
|
}
|
|
120
137
|
|
|
138
|
+
bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
|
|
139
|
+
bool res = true;
|
|
140
|
+
|
|
141
|
+
res &= n_outputs == params.n_outputs;
|
|
142
|
+
|
|
143
|
+
return res;
|
|
144
|
+
}
|
|
145
|
+
|
|
121
146
|
void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|
122
147
|
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
|
123
148
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
@@ -287,6 +312,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
|
|
287
312
|
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
|
288
313
|
}
|
|
289
314
|
|
|
315
|
+
bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
|
|
316
|
+
const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
|
|
317
|
+
|
|
318
|
+
this->mctx = mctx;
|
|
319
|
+
|
|
320
|
+
bool res = true;
|
|
321
|
+
|
|
322
|
+
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
323
|
+
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
324
|
+
|
|
325
|
+
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
|
|
326
|
+
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
|
327
|
+
|
|
328
|
+
res &= mctx->get_supports_set_rows(); // TODO: tmp
|
|
329
|
+
|
|
330
|
+
return res;
|
|
331
|
+
}
|
|
332
|
+
|
|
290
333
|
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
|
291
334
|
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
|
292
335
|
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
|
@@ -299,6 +342,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
|
|
|
299
342
|
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
|
300
343
|
}
|
|
301
344
|
|
|
345
|
+
bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
|
|
346
|
+
const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
|
|
347
|
+
|
|
348
|
+
this->mctx = mctx;
|
|
349
|
+
|
|
350
|
+
bool res = true;
|
|
351
|
+
|
|
352
|
+
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
|
353
|
+
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
354
|
+
|
|
355
|
+
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
|
356
|
+
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
|
357
|
+
|
|
358
|
+
res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
|
|
359
|
+
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
|
360
|
+
|
|
361
|
+
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
|
|
362
|
+
res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
|
363
|
+
|
|
364
|
+
res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
|
|
365
|
+
|
|
366
|
+
return res;
|
|
367
|
+
}
|
|
368
|
+
|
|
302
369
|
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
303
370
|
GGML_ASSERT(cross_kq_mask);
|
|
304
371
|
|
|
@@ -306,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
|
306
373
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
307
374
|
|
|
308
375
|
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
|
309
|
-
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
|
376
|
+
GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
|
|
310
377
|
|
|
311
378
|
float * data = (float *) cross_kq_mask->data;
|
|
312
379
|
|
|
@@ -340,6 +407,91 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
|
|
340
407
|
inp_rs->set_input(ubatch);
|
|
341
408
|
}
|
|
342
409
|
|
|
410
|
+
//
|
|
411
|
+
// llm_graph_result
|
|
412
|
+
//
|
|
413
|
+
|
|
414
|
+
llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
|
|
415
|
+
reset();
|
|
416
|
+
|
|
417
|
+
const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
|
|
418
|
+
debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
|
|
419
|
+
}
|
|
420
|
+
|
|
421
|
+
int64_t llm_graph_result::get_max_nodes() const {
|
|
422
|
+
return max_nodes;
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
void llm_graph_result::reset() {
|
|
426
|
+
t_tokens = nullptr;
|
|
427
|
+
t_logits = nullptr;
|
|
428
|
+
t_embd = nullptr;
|
|
429
|
+
t_embd_pooled = nullptr;
|
|
430
|
+
|
|
431
|
+
params = {};
|
|
432
|
+
|
|
433
|
+
inputs.clear();
|
|
434
|
+
|
|
435
|
+
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
|
436
|
+
|
|
437
|
+
ggml_init_params params = {
|
|
438
|
+
/*.mem_size =*/ buf_compute_meta.size(),
|
|
439
|
+
/*.mem_buffer =*/ buf_compute_meta.data(),
|
|
440
|
+
/*.no_alloc =*/ true,
|
|
441
|
+
};
|
|
442
|
+
|
|
443
|
+
ctx_compute.reset(ggml_init(params));
|
|
444
|
+
|
|
445
|
+
gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
|
|
446
|
+
}
|
|
447
|
+
|
|
448
|
+
void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
|
|
449
|
+
for (auto & input : inputs) {
|
|
450
|
+
input->set_input(ubatch);
|
|
451
|
+
}
|
|
452
|
+
}
|
|
453
|
+
|
|
454
|
+
bool llm_graph_result::can_reuse(const llm_graph_params & params) {
|
|
455
|
+
if (!this->params.allow_reuse(params)) {
|
|
456
|
+
if (debug > 1) {
|
|
457
|
+
LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
|
|
458
|
+
}
|
|
459
|
+
|
|
460
|
+
return false;
|
|
461
|
+
}
|
|
462
|
+
|
|
463
|
+
if (debug > 1) {
|
|
464
|
+
LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
bool res = true;
|
|
468
|
+
|
|
469
|
+
for (auto & input : inputs) {
|
|
470
|
+
const bool cur = input->can_reuse(params);
|
|
471
|
+
|
|
472
|
+
if (debug > 1) {
|
|
473
|
+
LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
|
|
474
|
+
}
|
|
475
|
+
|
|
476
|
+
res = res && cur;
|
|
477
|
+
}
|
|
478
|
+
|
|
479
|
+
if (debug > 0) {
|
|
480
|
+
LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
|
|
481
|
+
}
|
|
482
|
+
|
|
483
|
+
return res;
|
|
484
|
+
}
|
|
485
|
+
|
|
486
|
+
llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
|
|
487
|
+
inputs.emplace_back(std::move(input));
|
|
488
|
+
return inputs.back().get();
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
void llm_graph_result::set_params(const llm_graph_params & params) {
|
|
492
|
+
this->params = params;
|
|
493
|
+
}
|
|
494
|
+
|
|
343
495
|
//
|
|
344
496
|
// llm_graph_context
|
|
345
497
|
//
|
|
@@ -374,7 +526,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
|
374
526
|
n_ctx_orig (cparams.n_ctx_orig_yarn),
|
|
375
527
|
pooling_type (cparams.pooling_type),
|
|
376
528
|
rope_type (hparams.rope_type),
|
|
377
|
-
ctx0 (params.ctx),
|
|
378
529
|
sched (params.sched),
|
|
379
530
|
backend_cpu (params.backend_cpu),
|
|
380
531
|
cvec (params.cvec),
|
|
@@ -382,7 +533,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
|
382
533
|
mctx (params.mctx),
|
|
383
534
|
cross (params.cross),
|
|
384
535
|
cb_func (params.cb),
|
|
385
|
-
res (
|
|
536
|
+
res (params.res),
|
|
537
|
+
ctx0 (res->get_ctx()),
|
|
538
|
+
gf (res->get_gf()) {
|
|
539
|
+
res->set_params(params);
|
|
386
540
|
}
|
|
387
541
|
|
|
388
542
|
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
|
|
@@ -753,20 +907,28 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
753
907
|
cb(cur, "ffn_moe_weighted", il);
|
|
754
908
|
}
|
|
755
909
|
|
|
910
|
+
ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
|
|
911
|
+
|
|
912
|
+
assert(n_expert_used > 0);
|
|
913
|
+
|
|
914
|
+
// order the views before the adds
|
|
915
|
+
for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
|
|
916
|
+
cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
|
|
917
|
+
|
|
918
|
+
ggml_build_forward_expand(gf, cur_experts[i]);
|
|
919
|
+
}
|
|
920
|
+
|
|
756
921
|
// aggregate experts
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
922
|
+
// note: here we explicitly use hparams.n_expert_used instead of n_expert_used
|
|
923
|
+
// to avoid potentially a large number of add nodes during warmup
|
|
924
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/14753
|
|
925
|
+
ggml_tensor * moe_out = cur_experts[0];
|
|
761
926
|
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
} else {
|
|
765
|
-
moe_out = ggml_add(ctx0, moe_out, cur_expert);
|
|
766
|
-
}
|
|
927
|
+
for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
|
|
928
|
+
moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
|
|
767
929
|
}
|
|
768
930
|
|
|
769
|
-
if (n_expert_used == 1) {
|
|
931
|
+
if (hparams.n_expert_used == 1) {
|
|
770
932
|
// avoid returning a non-contiguous tensor
|
|
771
933
|
moe_out = ggml_cont(ctx0, moe_out);
|
|
772
934
|
}
|
|
@@ -972,7 +1134,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
|
|
|
972
1134
|
}
|
|
973
1135
|
|
|
974
1136
|
ggml_tensor * llm_graph_context::build_attn_mha(
|
|
975
|
-
ggml_cgraph * gf,
|
|
976
1137
|
ggml_tensor * q,
|
|
977
1138
|
ggml_tensor * k,
|
|
978
1139
|
ggml_tensor * v,
|
|
@@ -982,13 +1143,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
|
982
1143
|
float kq_scale) const {
|
|
983
1144
|
const bool v_trans = v->nb[1] > v->nb[2];
|
|
984
1145
|
|
|
1146
|
+
// split the batch into streams if needed
|
|
1147
|
+
const auto n_stream = k->ne[3];
|
|
1148
|
+
|
|
1149
|
+
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
|
|
1150
|
+
|
|
985
1151
|
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
|
986
1152
|
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
|
987
1153
|
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
|
988
1154
|
|
|
989
|
-
const auto
|
|
990
|
-
const auto n_head = q->ne[2];
|
|
991
|
-
const auto n_kv = k->ne[1];
|
|
1155
|
+
const auto n_kv = k->ne[1];
|
|
992
1156
|
|
|
993
1157
|
ggml_tensor * cur;
|
|
994
1158
|
|
|
@@ -1030,7 +1194,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
|
1030
1194
|
#endif
|
|
1031
1195
|
}
|
|
1032
1196
|
|
|
1033
|
-
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*
|
|
1197
|
+
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
|
1034
1198
|
} else {
|
|
1035
1199
|
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
|
1036
1200
|
|
|
@@ -1075,7 +1239,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
|
1075
1239
|
|
|
1076
1240
|
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
|
1077
1241
|
|
|
1078
|
-
|
|
1242
|
+
// recombine streams
|
|
1243
|
+
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
|
1079
1244
|
|
|
1080
1245
|
if (!cparams.offload_kqv) {
|
|
1081
1246
|
// all nodes between the KV store and the attention output are run on the CPU
|
|
@@ -1102,7 +1267,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
|
|
1102
1267
|
|
|
1103
1268
|
ggml_tensor * llm_graph_context::build_attn(
|
|
1104
1269
|
llm_graph_input_attn_no_cache * inp,
|
|
1105
|
-
ggml_cgraph * gf,
|
|
1106
1270
|
ggml_tensor * wo,
|
|
1107
1271
|
ggml_tensor * wo_b,
|
|
1108
1272
|
ggml_tensor * q_cur,
|
|
@@ -1122,11 +1286,15 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1122
1286
|
|
|
1123
1287
|
const auto & kq_mask = inp->get_kq_mask();
|
|
1124
1288
|
|
|
1289
|
+
// [TAG_NO_CACHE_PAD]
|
|
1290
|
+
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
|
|
1291
|
+
assert(!ubatch.equal_seqs());
|
|
1292
|
+
|
|
1125
1293
|
ggml_tensor * q = q_cur;
|
|
1126
1294
|
ggml_tensor * k = k_cur;
|
|
1127
1295
|
ggml_tensor * v = v_cur;
|
|
1128
1296
|
|
|
1129
|
-
ggml_tensor * cur = build_attn_mha(
|
|
1297
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1130
1298
|
cb(cur, "kqv_out", il);
|
|
1131
1299
|
|
|
1132
1300
|
if (wo) {
|
|
@@ -1156,13 +1324,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
|
|
|
1156
1324
|
{
|
|
1157
1325
|
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
|
1158
1326
|
|
|
1159
|
-
const auto n_kv
|
|
1327
|
+
const auto n_kv = mctx_cur->get_n_kv();
|
|
1160
1328
|
const auto n_tokens = ubatch.n_tokens;
|
|
1329
|
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
1161
1330
|
|
|
1162
1331
|
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
|
1163
1332
|
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
|
1164
1333
|
|
|
1165
|
-
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1,
|
|
1334
|
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
|
1166
1335
|
ggml_set_input(inp->self_kq_mask);
|
|
1167
1336
|
|
|
1168
1337
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
@@ -1181,7 +1350,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
|
|
|
1181
1350
|
|
|
1182
1351
|
ggml_tensor * llm_graph_context::build_attn(
|
|
1183
1352
|
llm_graph_input_attn_kv_unified * inp,
|
|
1184
|
-
ggml_cgraph * gf,
|
|
1185
1353
|
ggml_tensor * wo,
|
|
1186
1354
|
ggml_tensor * wo_b,
|
|
1187
1355
|
ggml_tensor * q_cur,
|
|
@@ -1214,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1214
1382
|
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
|
1215
1383
|
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
|
1216
1384
|
|
|
1217
|
-
ggml_tensor * cur = build_attn_mha(
|
|
1385
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1218
1386
|
cb(cur, "kqv_out", il);
|
|
1219
1387
|
|
|
1220
1388
|
if (wo) {
|
|
@@ -1234,7 +1402,6 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1234
1402
|
|
|
1235
1403
|
ggml_tensor * llm_graph_context::build_attn(
|
|
1236
1404
|
llm_graph_input_attn_kv_unified_iswa * inp,
|
|
1237
|
-
ggml_cgraph * gf,
|
|
1238
1405
|
ggml_tensor * wo,
|
|
1239
1406
|
ggml_tensor * wo_b,
|
|
1240
1407
|
ggml_tensor * q_cur,
|
|
@@ -1281,7 +1448,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1281
1448
|
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
|
1282
1449
|
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
|
1283
1450
|
|
|
1284
|
-
ggml_tensor * cur = build_attn_mha(
|
|
1451
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1285
1452
|
cb(cur, "kqv_out", il);
|
|
1286
1453
|
|
|
1287
1454
|
if (wo) {
|
|
@@ -1314,7 +1481,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|
|
1314
1481
|
|
|
1315
1482
|
ggml_tensor * llm_graph_context::build_attn(
|
|
1316
1483
|
llm_graph_input_attn_cross * inp,
|
|
1317
|
-
ggml_cgraph * gf,
|
|
1318
1484
|
ggml_tensor * wo,
|
|
1319
1485
|
ggml_tensor * wo_b,
|
|
1320
1486
|
ggml_tensor * q_cur,
|
|
@@ -1336,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1336
1502
|
ggml_tensor * k = k_cur;
|
|
1337
1503
|
ggml_tensor * v = v_cur;
|
|
1338
1504
|
|
|
1339
|
-
ggml_tensor * cur = build_attn_mha(
|
|
1505
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1340
1506
|
cb(cur, "kqv_out", il);
|
|
1341
1507
|
|
|
1342
1508
|
if (wo) {
|
|
@@ -1362,13 +1528,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|
|
1362
1528
|
|
|
1363
1529
|
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
|
1364
1530
|
|
|
1531
|
+
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
1532
|
+
|
|
1365
1533
|
{
|
|
1366
1534
|
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
|
1367
1535
|
|
|
1368
1536
|
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
|
1369
1537
|
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
|
1370
1538
|
|
|
1371
|
-
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1,
|
|
1539
|
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
|
1372
1540
|
ggml_set_input(inp->self_kq_mask);
|
|
1373
1541
|
|
|
1374
1542
|
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
@@ -1382,7 +1550,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|
|
1382
1550
|
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
|
1383
1551
|
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
|
1384
1552
|
|
|
1385
|
-
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1,
|
|
1553
|
+
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
|
1386
1554
|
ggml_set_input(inp->self_kq_mask_swa);
|
|
1387
1555
|
|
|
1388
1556
|
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
|
@@ -1392,7 +1560,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|
|
1392
1560
|
}
|
|
1393
1561
|
|
|
1394
1562
|
ggml_tensor * llm_graph_context::build_rs(
|
|
1395
|
-
ggml_cgraph * gf,
|
|
1396
1563
|
ggml_tensor * s,
|
|
1397
1564
|
ggml_tensor * state_copy,
|
|
1398
1565
|
int32_t state_size,
|
|
@@ -1450,21 +1617,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
|
|
1450
1617
|
|
|
1451
1618
|
ggml_tensor * llm_graph_context::build_rs(
|
|
1452
1619
|
llm_graph_input_rs * inp,
|
|
1453
|
-
ggml_cgraph * gf,
|
|
1454
1620
|
ggml_tensor * s,
|
|
1455
1621
|
int32_t state_size,
|
|
1456
1622
|
int32_t n_seqs,
|
|
1457
1623
|
const llm_graph_get_rows_fn & get_state_rows) const {
|
|
1458
1624
|
const auto * kv_state = inp->mctx;
|
|
1459
1625
|
|
|
1460
|
-
return build_rs(
|
|
1626
|
+
return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
|
|
1461
1627
|
}
|
|
1462
1628
|
|
|
1463
1629
|
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
1464
1630
|
llm_graph_input_rs * inp,
|
|
1465
|
-
ggml_cgraph * gf,
|
|
1466
1631
|
const llama_ubatch & ubatch,
|
|
1467
|
-
|
|
1632
|
+
int il) const {
|
|
1468
1633
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
|
1469
1634
|
|
|
1470
1635
|
const auto token_shift_count = hparams.token_shift_count;
|
|
@@ -1474,7 +1639,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
|
1474
1639
|
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
|
|
1475
1640
|
|
|
1476
1641
|
ggml_tensor * token_shift = build_rs(
|
|
1477
|
-
inp,
|
|
1642
|
+
inp, token_shift_all,
|
|
1478
1643
|
hparams.n_embd_r(), n_seqs);
|
|
1479
1644
|
|
|
1480
1645
|
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
|
@@ -1514,7 +1679,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|
|
1514
1679
|
}
|
|
1515
1680
|
|
|
1516
1681
|
void llm_graph_context::build_pooling(
|
|
1517
|
-
ggml_cgraph * gf,
|
|
1518
1682
|
ggml_tensor * cls,
|
|
1519
1683
|
ggml_tensor * cls_b,
|
|
1520
1684
|
ggml_tensor * cls_out,
|