@novastera-oss/llamarn 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +4 -3
- package/cpp/llama.cpp/common/arg.cpp +45 -1
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +18 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
- package/cpp/llama.cpp/include/llama.h +15 -7
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
- package/cpp/llama.cpp/src/llama-arch.h +5 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
- package/cpp/llama.cpp/src/llama-batch.h +24 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
- package/cpp/llama.cpp/src/llama-chat.h +2 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
- package/cpp/llama.cpp/src/llama-graph.h +147 -72
- package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
- package/cpp/llama.cpp/src/llama-hparams.h +10 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
- package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
- package/cpp/llama.cpp/src/llama-model.h +3 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
- package/cpp/llama.cpp/src/llama-vocab.h +2 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/common.h +18 -4
- package/ios/include/llama.h +15 -7
- package/ios/libs/llama.xcframework/Info.plist +15 -15
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
#pragma once
|
|
2
2
|
|
|
3
3
|
#include "llama-arch.h"
|
|
4
|
+
#include "llama-batch.h"
|
|
4
5
|
#include "llama-hparams.h"
|
|
5
6
|
#include "llama-adapter.h"
|
|
6
7
|
|
|
@@ -14,7 +15,6 @@ struct ggml_cgraph;
|
|
|
14
15
|
struct ggml_context;
|
|
15
16
|
struct ggml_tensor;
|
|
16
17
|
|
|
17
|
-
struct llama_ubatch;
|
|
18
18
|
struct llama_cparams;
|
|
19
19
|
|
|
20
20
|
struct llama_memory_context_i;
|
|
@@ -69,6 +69,8 @@ struct llama_cross {
|
|
|
69
69
|
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
|
70
70
|
};
|
|
71
71
|
|
|
72
|
+
struct llm_graph_params;
|
|
73
|
+
|
|
72
74
|
//
|
|
73
75
|
// llm_graph_input
|
|
74
76
|
//
|
|
@@ -78,11 +80,19 @@ public:
|
|
|
78
80
|
virtual ~llm_graph_input_i() = default;
|
|
79
81
|
|
|
80
82
|
virtual void set_input(const llama_ubatch * ubatch) = 0;
|
|
83
|
+
|
|
84
|
+
// return true if the resulting input tensors using the provided graph parameters would be
|
|
85
|
+
// the same as the previous input tensors that we have currently stored in the object
|
|
86
|
+
virtual bool can_reuse(const llm_graph_params & params) {
|
|
87
|
+
// returning false here by default will prevent from reusing the graph if the check
|
|
88
|
+
// for the input type has not been implemented yet
|
|
89
|
+
GGML_UNUSED(params);
|
|
90
|
+
return false;
|
|
91
|
+
}
|
|
81
92
|
};
|
|
82
93
|
|
|
83
94
|
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
|
|
84
95
|
|
|
85
|
-
|
|
86
96
|
class llm_graph_input_embd : public llm_graph_input_i {
|
|
87
97
|
public:
|
|
88
98
|
llm_graph_input_embd() = default;
|
|
@@ -90,6 +100,8 @@ public:
|
|
|
90
100
|
|
|
91
101
|
void set_input(const llama_ubatch * ubatch) override;
|
|
92
102
|
|
|
103
|
+
bool can_reuse(const llm_graph_params & params) override;
|
|
104
|
+
|
|
93
105
|
ggml_tensor * tokens = nullptr; // I32 [n_batch]
|
|
94
106
|
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
|
|
95
107
|
};
|
|
@@ -101,6 +113,8 @@ public:
|
|
|
101
113
|
|
|
102
114
|
void set_input(const llama_ubatch * ubatch) override;
|
|
103
115
|
|
|
116
|
+
bool can_reuse(const llm_graph_params & params) override;
|
|
117
|
+
|
|
104
118
|
ggml_tensor * pos = nullptr; // I32 [n_batch]
|
|
105
119
|
|
|
106
120
|
const uint32_t n_pos_per_embd = 1;
|
|
@@ -154,17 +168,19 @@ public:
|
|
|
154
168
|
llm_graph_input_out_ids(
|
|
155
169
|
const llama_hparams & hparams,
|
|
156
170
|
const llama_cparams & cparams,
|
|
157
|
-
|
|
171
|
+
uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
|
|
158
172
|
virtual ~llm_graph_input_out_ids() = default;
|
|
159
173
|
|
|
160
174
|
void set_input(const llama_ubatch * ubatch) override;
|
|
161
175
|
|
|
176
|
+
bool can_reuse(const llm_graph_params & params) override;
|
|
177
|
+
|
|
162
178
|
ggml_tensor * out_ids; // I32 [n_outputs]
|
|
163
179
|
|
|
164
180
|
const llama_hparams & hparams;
|
|
165
181
|
const llama_cparams & cparams;
|
|
166
182
|
|
|
167
|
-
const
|
|
183
|
+
const uint32_t n_outputs;
|
|
168
184
|
};
|
|
169
185
|
|
|
170
186
|
class llm_graph_input_mean : public llm_graph_input_i {
|
|
@@ -249,16 +265,18 @@ public:
|
|
|
249
265
|
|
|
250
266
|
void set_input(const llama_ubatch * ubatch) override;
|
|
251
267
|
|
|
268
|
+
bool can_reuse(const llm_graph_params & params) override;
|
|
269
|
+
|
|
252
270
|
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
|
253
271
|
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
|
|
254
272
|
|
|
255
273
|
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
|
256
274
|
|
|
257
275
|
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
|
258
|
-
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
|
276
|
+
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
|
259
277
|
|
|
260
|
-
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1,
|
|
261
|
-
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1,
|
|
278
|
+
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
|
279
|
+
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
|
262
280
|
|
|
263
281
|
const llama_hparams & hparams;
|
|
264
282
|
const llama_cparams & cparams;
|
|
@@ -280,6 +298,8 @@ public:
|
|
|
280
298
|
|
|
281
299
|
void set_input(const llama_ubatch * ubatch) override;
|
|
282
300
|
|
|
301
|
+
bool can_reuse(const llm_graph_params & params) override;
|
|
302
|
+
|
|
283
303
|
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
|
|
284
304
|
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
|
|
285
305
|
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
|
|
@@ -289,14 +309,14 @@ public:
|
|
|
289
309
|
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
|
290
310
|
|
|
291
311
|
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
|
292
|
-
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
|
312
|
+
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
|
293
313
|
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
|
|
294
|
-
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
|
|
314
|
+
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
|
295
315
|
|
|
296
|
-
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1,
|
|
297
|
-
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1,
|
|
298
|
-
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1,
|
|
299
|
-
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1,
|
|
316
|
+
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
|
317
|
+
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
|
318
|
+
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
|
319
|
+
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
|
300
320
|
|
|
301
321
|
const llama_hparams & hparams;
|
|
302
322
|
const llama_cparams & cparams;
|
|
@@ -351,40 +371,108 @@ public:
|
|
|
351
371
|
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
|
|
352
372
|
// these are used by the llama_context to extact the relevant data, based on the compute parameters
|
|
353
373
|
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
virtual ~llm_graph_result_i() = default;
|
|
374
|
+
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
|
|
375
|
+
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
|
|
357
376
|
|
|
358
|
-
|
|
359
|
-
virtual ggml_tensor * get_logits() = 0;
|
|
360
|
-
virtual ggml_tensor * get_embd() = 0;
|
|
361
|
-
virtual ggml_tensor * get_embd_pooled() = 0;
|
|
377
|
+
class llm_graph_result;
|
|
362
378
|
|
|
363
|
-
|
|
364
|
-
|
|
379
|
+
struct llm_graph_params {
|
|
380
|
+
llm_arch arch = LLM_ARCH_UNKNOWN;
|
|
365
381
|
|
|
366
|
-
|
|
382
|
+
llama_hparams hparams;
|
|
383
|
+
llama_cparams cparams;
|
|
367
384
|
|
|
385
|
+
llama_ubatch ubatch; // note: intentionally make a copy
|
|
368
386
|
|
|
369
|
-
|
|
370
|
-
public:
|
|
371
|
-
virtual ~llm_graph_result() = default;
|
|
387
|
+
llm_graph_type gtype;
|
|
372
388
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
ggml_tensor * get_embd() override { return t_embd; }
|
|
376
|
-
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
|
|
389
|
+
ggml_backend_sched_t sched;
|
|
390
|
+
ggml_backend_t backend_cpu;
|
|
377
391
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
392
|
+
const llama_adapter_cvec * cvec;
|
|
393
|
+
const llama_adapter_loras * loras;
|
|
394
|
+
const llama_memory_context_i * mctx;
|
|
395
|
+
const llama_cross * cross;
|
|
396
|
+
|
|
397
|
+
uint32_t n_outputs;
|
|
398
|
+
|
|
399
|
+
llm_graph_cb cb;
|
|
400
|
+
|
|
401
|
+
llm_graph_result * res;
|
|
402
|
+
|
|
403
|
+
// return true if the "other" params would result in a graph with the same topology as with the current params
|
|
404
|
+
// having the same topology allows us to reuse the graph in some cases
|
|
405
|
+
bool allow_reuse(const llm_graph_params & other) const {
|
|
406
|
+
// first check the ubatch
|
|
407
|
+
bool can_reuse_ubatch =
|
|
408
|
+
ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
|
|
409
|
+
ubatch.n_tokens == other.ubatch.n_tokens &&
|
|
410
|
+
ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
|
|
411
|
+
ubatch.n_seqs == other.ubatch.n_seqs &&
|
|
412
|
+
ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
|
|
413
|
+
(
|
|
414
|
+
(!ubatch.token && !other.ubatch.token) ||
|
|
415
|
+
(!ubatch.embd && !other.ubatch.embd)
|
|
416
|
+
);
|
|
417
|
+
|
|
418
|
+
if (can_reuse_ubatch && !ubatch.equal_seqs()) {
|
|
419
|
+
if (!ubatch.data) {
|
|
420
|
+
// if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
|
|
421
|
+
// therefore we cannot perform the sequence id check. normally should never happen
|
|
422
|
+
can_reuse_ubatch = false;
|
|
423
|
+
} else {
|
|
424
|
+
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
|
425
|
+
can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
|
|
426
|
+
}
|
|
427
|
+
}
|
|
381
428
|
}
|
|
382
|
-
}
|
|
383
429
|
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
430
|
+
if (!can_reuse_ubatch) {
|
|
431
|
+
return false;
|
|
432
|
+
}
|
|
433
|
+
|
|
434
|
+
return
|
|
435
|
+
cparams.embeddings == other.cparams.embeddings &&
|
|
436
|
+
cparams.causal_attn == other.cparams.causal_attn &&
|
|
437
|
+
arch == other.arch &&
|
|
438
|
+
gtype == other.gtype &&
|
|
439
|
+
cvec == other.cvec &&
|
|
440
|
+
loras == other.loras &&
|
|
441
|
+
cross == other.cross &&
|
|
442
|
+
n_outputs == other.n_outputs;
|
|
387
443
|
}
|
|
444
|
+
};
|
|
445
|
+
|
|
446
|
+
class llm_graph_result {
|
|
447
|
+
public:
|
|
448
|
+
llm_graph_result(int64_t max_nodes);
|
|
449
|
+
|
|
450
|
+
virtual ~llm_graph_result() = default;
|
|
451
|
+
|
|
452
|
+
ggml_tensor * get_tokens() const { return t_tokens; }
|
|
453
|
+
ggml_tensor * get_logits() const { return t_logits; }
|
|
454
|
+
ggml_tensor * get_embd() const { return t_embd; }
|
|
455
|
+
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
|
|
456
|
+
|
|
457
|
+
ggml_cgraph * get_gf() const { return gf; }
|
|
458
|
+
ggml_context * get_ctx() const { return ctx_compute.get(); }
|
|
459
|
+
|
|
460
|
+
int64_t get_max_nodes() const;
|
|
461
|
+
|
|
462
|
+
void reset();
|
|
463
|
+
|
|
464
|
+
void set_inputs(const llama_ubatch * ubatch);
|
|
465
|
+
|
|
466
|
+
// try to update the existing graph result using the new graph parameters in order to reuse it
|
|
467
|
+
// this can only be done if we determine that the resulting graph using the new graph parameters
|
|
468
|
+
// would be identical to the existing graph. in that case, we simply have to update the memory
|
|
469
|
+
// contexts of the input tensors of the graph and we can reuse it for another computation
|
|
470
|
+
// return true if the graph was updated and can be reused
|
|
471
|
+
bool can_reuse(const llm_graph_params & params);
|
|
472
|
+
|
|
473
|
+
llm_graph_input_i * add_input(llm_graph_input_ptr input);
|
|
474
|
+
|
|
475
|
+
void set_params(const llm_graph_params & params);
|
|
388
476
|
|
|
389
477
|
// important graph nodes
|
|
390
478
|
ggml_tensor * t_tokens = nullptr;
|
|
@@ -393,36 +481,31 @@ public:
|
|
|
393
481
|
ggml_tensor * t_embd_pooled = nullptr;
|
|
394
482
|
|
|
395
483
|
std::vector<llm_graph_input_ptr> inputs;
|
|
396
|
-
};
|
|
397
484
|
|
|
398
|
-
|
|
399
|
-
// llm_graph_context
|
|
400
|
-
//
|
|
485
|
+
ggml_context_ptr ctx_compute;
|
|
401
486
|
|
|
402
|
-
//
|
|
403
|
-
|
|
487
|
+
// memory buffers used to evaluate the model
|
|
488
|
+
std::vector<uint8_t> buf_compute_meta;
|
|
404
489
|
|
|
405
|
-
|
|
406
|
-
ggml_context * ctx;
|
|
490
|
+
ggml_cgraph * gf;
|
|
407
491
|
|
|
408
|
-
|
|
492
|
+
int64_t max_nodes;
|
|
409
493
|
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
494
|
+
private:
|
|
495
|
+
// keep a copy of the previous graph parameters
|
|
496
|
+
// we will use this to determine whether the graph can be reused by comparing them with the new parameters
|
|
497
|
+
// note: these are updated after constructing the new graph
|
|
498
|
+
llm_graph_params params;
|
|
413
499
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
const llama_adapter_cvec * cvec;
|
|
418
|
-
const llama_adapter_loras * loras;
|
|
419
|
-
const llama_memory_context_i * mctx;
|
|
420
|
-
const llama_cross * cross;
|
|
500
|
+
// env: LLAMA_GRAPH_RESULT_DEBUG
|
|
501
|
+
int debug = 0;
|
|
502
|
+
};
|
|
421
503
|
|
|
422
|
-
|
|
504
|
+
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
|
|
423
505
|
|
|
424
|
-
|
|
425
|
-
|
|
506
|
+
//
|
|
507
|
+
// llm_graph_context
|
|
508
|
+
//
|
|
426
509
|
|
|
427
510
|
// used in build_rs to properly order writes and avoid unnecessary copies
|
|
428
511
|
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
|
|
@@ -463,8 +546,6 @@ struct llm_graph_context {
|
|
|
463
546
|
const enum llama_pooling_type pooling_type;
|
|
464
547
|
const enum llama_rope_type rope_type;
|
|
465
548
|
|
|
466
|
-
ggml_context * ctx0 = nullptr;
|
|
467
|
-
|
|
468
549
|
ggml_backend_sched_t sched;
|
|
469
550
|
|
|
470
551
|
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
|
@@ -476,7 +557,10 @@ struct llm_graph_context {
|
|
|
476
557
|
|
|
477
558
|
const llm_graph_cb & cb_func;
|
|
478
559
|
|
|
479
|
-
|
|
560
|
+
llm_graph_result * res;
|
|
561
|
+
|
|
562
|
+
ggml_context * ctx0 = nullptr;
|
|
563
|
+
ggml_cgraph * gf = nullptr;
|
|
480
564
|
|
|
481
565
|
llm_graph_context(const llm_graph_params & params);
|
|
482
566
|
virtual ~llm_graph_context() = default;
|
|
@@ -562,7 +646,6 @@ struct llm_graph_context {
|
|
|
562
646
|
//
|
|
563
647
|
|
|
564
648
|
ggml_tensor * build_attn_mha(
|
|
565
|
-
ggml_cgraph * gf,
|
|
566
649
|
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
|
|
567
650
|
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
|
|
568
651
|
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
|
@@ -575,7 +658,6 @@ struct llm_graph_context {
|
|
|
575
658
|
|
|
576
659
|
ggml_tensor * build_attn(
|
|
577
660
|
llm_graph_input_attn_no_cache * inp,
|
|
578
|
-
ggml_cgraph * gf,
|
|
579
661
|
ggml_tensor * wo,
|
|
580
662
|
ggml_tensor * wo_b,
|
|
581
663
|
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
@@ -590,7 +672,6 @@ struct llm_graph_context {
|
|
|
590
672
|
|
|
591
673
|
ggml_tensor * build_attn(
|
|
592
674
|
llm_graph_input_attn_kv_unified * inp,
|
|
593
|
-
ggml_cgraph * gf,
|
|
594
675
|
ggml_tensor * wo,
|
|
595
676
|
ggml_tensor * wo_b,
|
|
596
677
|
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
@@ -606,7 +687,6 @@ struct llm_graph_context {
|
|
|
606
687
|
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
|
|
607
688
|
ggml_tensor * build_attn(
|
|
608
689
|
llm_graph_input_attn_kv_unified_iswa * inp,
|
|
609
|
-
ggml_cgraph * gf,
|
|
610
690
|
ggml_tensor * wo,
|
|
611
691
|
ggml_tensor * wo_b,
|
|
612
692
|
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
@@ -621,7 +701,6 @@ struct llm_graph_context {
|
|
|
621
701
|
|
|
622
702
|
ggml_tensor * build_attn(
|
|
623
703
|
llm_graph_input_attn_cross * inp,
|
|
624
|
-
ggml_cgraph * gf,
|
|
625
704
|
ggml_tensor * wo,
|
|
626
705
|
ggml_tensor * wo_b,
|
|
627
706
|
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
@@ -643,7 +722,6 @@ struct llm_graph_context {
|
|
|
643
722
|
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
|
|
644
723
|
// `llama_memory_recurrent`
|
|
645
724
|
ggml_tensor * build_rs(
|
|
646
|
-
ggml_cgraph * gf,
|
|
647
725
|
ggml_tensor * s,
|
|
648
726
|
ggml_tensor * state_copy,
|
|
649
727
|
int32_t state_size,
|
|
@@ -658,7 +736,6 @@ struct llm_graph_context {
|
|
|
658
736
|
|
|
659
737
|
ggml_tensor * build_rs(
|
|
660
738
|
llm_graph_input_rs * inp,
|
|
661
|
-
ggml_cgraph * gf,
|
|
662
739
|
ggml_tensor * s,
|
|
663
740
|
int32_t state_size,
|
|
664
741
|
int32_t n_seqs,
|
|
@@ -666,9 +743,8 @@ struct llm_graph_context {
|
|
|
666
743
|
|
|
667
744
|
ggml_tensor * build_rwkv_token_shift_load(
|
|
668
745
|
llm_graph_input_rs * inp,
|
|
669
|
-
ggml_cgraph * gf,
|
|
670
746
|
const llama_ubatch & ubatch,
|
|
671
|
-
|
|
747
|
+
int il) const;
|
|
672
748
|
|
|
673
749
|
ggml_tensor * build_rwkv_token_shift_store(
|
|
674
750
|
ggml_tensor * token_shift,
|
|
@@ -685,7 +761,6 @@ struct llm_graph_context {
|
|
|
685
761
|
//
|
|
686
762
|
|
|
687
763
|
void build_pooling(
|
|
688
|
-
ggml_cgraph * gf,
|
|
689
764
|
ggml_tensor * cls,
|
|
690
765
|
ggml_tensor * cls_b,
|
|
691
766
|
ggml_tensor * cls_out,
|
|
@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
|
|
65
65
|
return n_embd_head_v * n_head_kv;
|
|
66
66
|
}
|
|
67
67
|
|
|
68
|
+
bool llama_hparams::is_n_embd_k_gqa_variable() const {
|
|
69
|
+
const uint32_t val = n_embd_k_gqa();
|
|
70
|
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
71
|
+
if (val != n_embd_k_gqa(il)) {
|
|
72
|
+
return true;
|
|
73
|
+
}
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
return false;
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
bool llama_hparams::is_n_embd_v_gqa_variable() const {
|
|
80
|
+
const uint32_t val = n_embd_v_gqa();
|
|
81
|
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
82
|
+
if (val != n_embd_v_gqa(il)) {
|
|
83
|
+
return true;
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
return false;
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
uint32_t llama_hparams::n_embd_k_gqa_max() const {
|
|
91
|
+
uint32_t val = n_embd_k_gqa();
|
|
92
|
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
93
|
+
val = std::max(val, n_embd_k_gqa(il));
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
return val;
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
uint32_t llama_hparams::n_embd_v_gqa_max() const {
|
|
100
|
+
uint32_t val = n_embd_v_gqa();
|
|
101
|
+
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
102
|
+
val = std::max(val, n_embd_v_gqa(il));
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
return val;
|
|
106
|
+
}
|
|
107
|
+
|
|
68
108
|
uint32_t llama_hparams::n_embd_r() const {
|
|
69
109
|
if (wkv_head_size != 0) {
|
|
70
110
|
// for RWKV models
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
|
|
7
7
|
// bump if necessary
|
|
8
8
|
#define LLAMA_MAX_LAYERS 512
|
|
9
|
-
#define LLAMA_MAX_EXPERTS
|
|
9
|
+
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
|
|
10
10
|
|
|
11
11
|
enum llama_expert_gating_func_type {
|
|
12
12
|
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
|
|
@@ -98,7 +98,7 @@ struct llama_hparams {
|
|
|
98
98
|
float rope_freq_scale_train;
|
|
99
99
|
float rope_freq_scale_train_swa;
|
|
100
100
|
uint32_t n_ctx_orig_yarn;
|
|
101
|
-
float rope_yarn_log_mul;
|
|
101
|
+
float rope_yarn_log_mul = 0.0f;
|
|
102
102
|
|
|
103
103
|
std::array<int, 4> rope_sections;
|
|
104
104
|
|
|
@@ -191,6 +191,14 @@ struct llama_hparams {
|
|
|
191
191
|
// dimension of value embeddings across all k-v heads
|
|
192
192
|
uint32_t n_embd_v_gqa(uint32_t il = 0) const;
|
|
193
193
|
|
|
194
|
+
// true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
|
|
195
|
+
bool is_n_embd_k_gqa_variable() const;
|
|
196
|
+
bool is_n_embd_v_gqa_variable() const;
|
|
197
|
+
|
|
198
|
+
// return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
|
|
199
|
+
uint32_t n_embd_k_gqa_max() const;
|
|
200
|
+
uint32_t n_embd_v_gqa_max() const;
|
|
201
|
+
|
|
194
202
|
// dimension of the rolling state embeddings
|
|
195
203
|
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
|
|
196
204
|
uint32_t n_embd_r() const;
|
|
@@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|
|
18
18
|
bool v_trans,
|
|
19
19
|
bool offload,
|
|
20
20
|
bool swa_full,
|
|
21
|
+
bool unified,
|
|
21
22
|
uint32_t kv_size,
|
|
22
23
|
uint32_t n_seq_max,
|
|
23
24
|
uint32_t n_ubatch,
|
|
24
|
-
uint32_t n_pad) : hparams(model.hparams) {
|
|
25
|
+
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
|
|
25
26
|
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
|
26
27
|
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
|
27
28
|
|
|
28
29
|
const uint32_t size_base = kv_size;
|
|
29
30
|
|
|
30
|
-
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
|
31
|
+
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
|
|
31
32
|
|
|
32
33
|
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
|
33
34
|
if (swa_full) {
|
|
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|
|
41
42
|
|
|
42
43
|
kv_base = std::make_unique<llama_kv_cache_unified>(
|
|
43
44
|
model, std::move(filter_base), type_k, type_v,
|
|
44
|
-
v_trans, offload, size_base, n_seq_max, n_pad,
|
|
45
|
+
v_trans, offload, unified, size_base, n_seq_max, n_pad,
|
|
45
46
|
0, LLAMA_SWA_TYPE_NONE);
|
|
46
47
|
|
|
47
48
|
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
|
48
49
|
|
|
49
50
|
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
|
50
51
|
model, std::move(filter_swa), type_k, type_v,
|
|
51
|
-
v_trans, offload, size_swa, n_seq_max, n_pad,
|
|
52
|
+
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
|
|
52
53
|
hparams.n_swa, hparams.swa_type);
|
|
53
54
|
}
|
|
54
55
|
|
|
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|
|
100
101
|
|
|
101
102
|
// first try simple split
|
|
102
103
|
do {
|
|
104
|
+
if (!unified) {
|
|
105
|
+
// requires equal splits, so we skip the simple split
|
|
106
|
+
break;
|
|
107
|
+
}
|
|
108
|
+
|
|
103
109
|
balloc.split_reset();
|
|
104
110
|
|
|
105
111
|
std::vector<llama_ubatch> ubatches;
|
|
@@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|
|
140
146
|
|
|
141
147
|
std::vector<llama_ubatch> ubatches;
|
|
142
148
|
while (true) {
|
|
143
|
-
auto ubatch = balloc.split_equal(n_ubatch,
|
|
149
|
+
auto ubatch = balloc.split_equal(n_ubatch, !unified);
|
|
144
150
|
|
|
145
151
|
if (ubatch.n_tokens == 0) {
|
|
146
152
|
break;
|
|
@@ -20,6 +20,7 @@ public:
|
|
|
20
20
|
bool v_trans,
|
|
21
21
|
bool offload,
|
|
22
22
|
bool swa_full,
|
|
23
|
+
bool unified,
|
|
23
24
|
uint32_t kv_size,
|
|
24
25
|
uint32_t n_seq_max,
|
|
25
26
|
uint32_t n_ubatch,
|
|
@@ -68,6 +69,8 @@ public:
|
|
|
68
69
|
private:
|
|
69
70
|
const llama_hparams & hparams;
|
|
70
71
|
|
|
72
|
+
const bool unified;
|
|
73
|
+
|
|
71
74
|
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
|
72
75
|
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
|
73
76
|
};
|