@novastera-oss/llamarn 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +134 -36
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -2
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +30 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +50 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +134 -36
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
- package/cpp/llama.cpp/src/llama-batch.h +36 -11
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +313 -213
- package/cpp/llama.cpp/src/llama-context.h +16 -12
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
- package/cpp/llama.cpp/src/llama-graph.h +90 -34
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
- package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +64 -23
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +726 -141
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/llama.h +134 -36
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
|
@@ -17,11 +17,12 @@ struct ggml_tensor;
|
|
|
17
17
|
struct llama_ubatch;
|
|
18
18
|
struct llama_cparams;
|
|
19
19
|
|
|
20
|
-
|
|
20
|
+
struct llama_memory_state_i;
|
|
21
21
|
|
|
22
22
|
class llama_kv_cache_unified_state;
|
|
23
23
|
class llama_kv_cache_unified_iswa_state;
|
|
24
|
-
class
|
|
24
|
+
class llama_memory_recurrent_state;
|
|
25
|
+
class llama_memory_hybrid_state;
|
|
25
26
|
|
|
26
27
|
// certain models (typically multi-modal) can produce different types of graphs
|
|
27
28
|
enum llm_graph_type {
|
|
@@ -36,6 +37,7 @@ enum llm_ffn_op_type {
|
|
|
36
37
|
LLM_FFN_RELU,
|
|
37
38
|
LLM_FFN_RELU_SQR,
|
|
38
39
|
LLM_FFN_SWIGLU,
|
|
40
|
+
LLM_FFN_GEGLU,
|
|
39
41
|
};
|
|
40
42
|
|
|
41
43
|
enum llm_ffn_gate_type {
|
|
@@ -187,28 +189,16 @@ public:
|
|
|
187
189
|
const llama_cparams & cparams;
|
|
188
190
|
};
|
|
189
191
|
|
|
190
|
-
class
|
|
192
|
+
class llm_graph_input_rs : public llm_graph_input_i {
|
|
191
193
|
public:
|
|
192
|
-
|
|
193
|
-
virtual ~
|
|
194
|
+
llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
|
|
195
|
+
virtual ~llm_graph_input_rs() = default;
|
|
194
196
|
|
|
195
197
|
void set_input(const llama_ubatch * ubatch) override;
|
|
196
198
|
|
|
197
199
|
ggml_tensor * s_copy; // I32 [kv_size]
|
|
198
200
|
|
|
199
|
-
const
|
|
200
|
-
};
|
|
201
|
-
|
|
202
|
-
class llm_graph_input_s_mask : public llm_graph_input_i {
|
|
203
|
-
public:
|
|
204
|
-
llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
|
205
|
-
virtual ~llm_graph_input_s_mask() = default;
|
|
206
|
-
|
|
207
|
-
void set_input(const llama_ubatch * ubatch) override;
|
|
208
|
-
|
|
209
|
-
ggml_tensor * s_mask; // F32 [1, n_kv]
|
|
210
|
-
|
|
211
|
-
const llama_kv_cache_recurrent_state * kv_state;
|
|
201
|
+
const llama_memory_recurrent_state * mem_state;
|
|
212
202
|
};
|
|
213
203
|
|
|
214
204
|
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
|
@@ -311,6 +301,33 @@ public:
|
|
|
311
301
|
const llama_cross * cross = nullptr;
|
|
312
302
|
};
|
|
313
303
|
|
|
304
|
+
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
|
|
305
|
+
public:
|
|
306
|
+
llm_graph_input_mem_hybrid(
|
|
307
|
+
const llama_hparams & hparams,
|
|
308
|
+
const llama_cparams & cparams,
|
|
309
|
+
const llama_memory_hybrid_state * mem_state) :
|
|
310
|
+
hparams(hparams),
|
|
311
|
+
cparams(cparams),
|
|
312
|
+
mem_state(mem_state) {
|
|
313
|
+
}
|
|
314
|
+
virtual ~llm_graph_input_mem_hybrid() = default;
|
|
315
|
+
|
|
316
|
+
void set_input(const llama_ubatch * ubatch) override;
|
|
317
|
+
|
|
318
|
+
ggml_tensor * s_copy; // I32 [kv_size]
|
|
319
|
+
|
|
320
|
+
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
|
321
|
+
|
|
322
|
+
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
|
323
|
+
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
|
324
|
+
|
|
325
|
+
const llama_hparams & hparams;
|
|
326
|
+
const llama_cparams & cparams;
|
|
327
|
+
|
|
328
|
+
const llama_memory_hybrid_state * mem_state;
|
|
329
|
+
};
|
|
330
|
+
|
|
314
331
|
//
|
|
315
332
|
// llm_graph_result
|
|
316
333
|
//
|
|
@@ -389,7 +406,7 @@ struct llm_graph_params {
|
|
|
389
406
|
const llama_memory_state_i * mstate;
|
|
390
407
|
const llama_cross * cross;
|
|
391
408
|
|
|
392
|
-
|
|
409
|
+
uint32_t n_outputs;
|
|
393
410
|
|
|
394
411
|
const llm_graph_cb & cb;
|
|
395
412
|
};
|
|
@@ -423,8 +440,8 @@ struct llm_graph_context {
|
|
|
423
440
|
const float norm_eps;
|
|
424
441
|
const float norm_rms_eps;
|
|
425
442
|
|
|
426
|
-
const
|
|
427
|
-
const
|
|
443
|
+
const int64_t n_tokens;
|
|
444
|
+
const int64_t n_outputs;
|
|
428
445
|
const int32_t n_ctx_orig; // yarn
|
|
429
446
|
|
|
430
447
|
const enum llama_pooling_type pooling_type;
|
|
@@ -519,14 +536,14 @@ struct llm_graph_context {
|
|
|
519
536
|
ggml_tensor * build_inp_out_ids() const;
|
|
520
537
|
ggml_tensor * build_inp_mean() const;
|
|
521
538
|
ggml_tensor * build_inp_cls() const;
|
|
522
|
-
ggml_tensor * build_inp_s_copy() const;
|
|
523
|
-
ggml_tensor * build_inp_s_mask() const;
|
|
524
539
|
|
|
525
540
|
ggml_tensor * build_inp_cross_embd() const;
|
|
526
541
|
ggml_tensor * build_inp_pos_bucket_enc() const;
|
|
527
542
|
ggml_tensor * build_inp_pos_bucket_dec() const;
|
|
528
543
|
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
|
|
529
544
|
|
|
545
|
+
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
|
|
546
|
+
|
|
530
547
|
//
|
|
531
548
|
// attention
|
|
532
549
|
//
|
|
@@ -601,23 +618,62 @@ struct llm_graph_context {
|
|
|
601
618
|
float kq_scale,
|
|
602
619
|
int il) const;
|
|
603
620
|
|
|
621
|
+
ggml_tensor * build_attn(
|
|
622
|
+
llm_graph_input_mem_hybrid * inp,
|
|
623
|
+
ggml_cgraph * gf,
|
|
624
|
+
ggml_tensor * wo,
|
|
625
|
+
ggml_tensor * wo_b,
|
|
626
|
+
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
|
627
|
+
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
|
628
|
+
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
|
629
|
+
ggml_tensor * kq_b,
|
|
630
|
+
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
|
631
|
+
float kq_scale,
|
|
632
|
+
int il) const;
|
|
604
633
|
//
|
|
605
634
|
// recurrent
|
|
606
635
|
//
|
|
607
636
|
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
637
|
+
// TODO: avoid notion of "kv"
|
|
638
|
+
// TODO: move this implementation to llama_memory_recurrent.
|
|
639
|
+
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
|
|
640
|
+
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
|
|
641
|
+
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
|
|
642
|
+
// `llama_memory_recurrent`
|
|
643
|
+
ggml_tensor * build_rs(
|
|
644
|
+
ggml_cgraph * gf,
|
|
645
|
+
ggml_tensor * s,
|
|
646
|
+
ggml_tensor * state_copy,
|
|
647
|
+
int32_t state_size,
|
|
648
|
+
int32_t n_seqs,
|
|
649
|
+
uint32_t n_kv,
|
|
650
|
+
uint32_t kv_head,
|
|
651
|
+
uint32_t kv_size,
|
|
652
|
+
int32_t rs_zero,
|
|
653
|
+
bool avoid_copies = false) const;
|
|
654
|
+
|
|
655
|
+
llm_graph_input_rs * build_rs_inp() const;
|
|
656
|
+
|
|
657
|
+
ggml_tensor * build_rs(
|
|
658
|
+
llm_graph_input_rs * inp,
|
|
659
|
+
ggml_cgraph * gf,
|
|
660
|
+
ggml_tensor * s,
|
|
661
|
+
int32_t state_size,
|
|
662
|
+
int32_t n_seqs,
|
|
663
|
+
bool avoid_copies = false) const;
|
|
664
|
+
|
|
665
|
+
ggml_tensor * build_rs(
|
|
666
|
+
llm_graph_input_mem_hybrid * inp,
|
|
667
|
+
ggml_cgraph * gf,
|
|
668
|
+
ggml_tensor * s,
|
|
669
|
+
int32_t state_size,
|
|
670
|
+
int32_t n_seqs,
|
|
671
|
+
bool avoid_copies = false) const;
|
|
615
672
|
|
|
616
673
|
ggml_tensor * build_rwkv_token_shift_load(
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
const llama_ubatch & ubatch,
|
|
674
|
+
llm_graph_input_rs * inp,
|
|
675
|
+
ggml_cgraph * gf,
|
|
676
|
+
const llama_ubatch & ubatch,
|
|
621
677
|
int il) const;
|
|
622
678
|
|
|
623
679
|
ggml_tensor * build_rwkv_token_shift_store(
|
|
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
|
|
65
65
|
return n_embd_head_v * n_head_kv;
|
|
66
66
|
}
|
|
67
67
|
|
|
68
|
-
uint32_t llama_hparams::
|
|
68
|
+
uint32_t llama_hparams::n_embd_r() const {
|
|
69
69
|
if (wkv_head_size != 0) {
|
|
70
70
|
// for RWKV models
|
|
71
71
|
return token_shift_count * n_embd;
|
|
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
|
|
|
76
76
|
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
|
|
77
77
|
}
|
|
78
78
|
|
|
79
|
-
uint32_t llama_hparams::
|
|
79
|
+
uint32_t llama_hparams::n_embd_s() const {
|
|
80
80
|
if (wkv_head_size != 0) {
|
|
81
81
|
// corresponds to RWKV's wkv_states size
|
|
82
82
|
return n_embd * wkv_head_size;
|
|
@@ -86,6 +86,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
|
|
|
86
86
|
return ssm_d_state * ssm_d_inner;
|
|
87
87
|
}
|
|
88
88
|
|
|
89
|
+
bool llama_hparams::is_recurrent(uint32_t il) const {
|
|
90
|
+
return recurrent_layer_arr[il];
|
|
91
|
+
}
|
|
92
|
+
|
|
89
93
|
bool llama_hparams::is_swa(uint32_t il) const {
|
|
90
94
|
if (il < n_layer) {
|
|
91
95
|
return swa_layers[il];
|
|
@@ -115,6 +115,9 @@ struct llama_hparams {
|
|
|
115
115
|
uint32_t ssm_d_state = 0;
|
|
116
116
|
uint32_t ssm_dt_rank = 0;
|
|
117
117
|
|
|
118
|
+
// for hybrid state space models
|
|
119
|
+
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
|
|
120
|
+
|
|
118
121
|
bool ssm_dt_b_c_rms = false;
|
|
119
122
|
|
|
120
123
|
float f_clamp_kqv = 0.0f;
|
|
@@ -181,10 +184,13 @@ struct llama_hparams {
|
|
|
181
184
|
|
|
182
185
|
// dimension of the rolling state embeddings
|
|
183
186
|
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
|
|
184
|
-
uint32_t
|
|
187
|
+
uint32_t n_embd_r() const;
|
|
185
188
|
|
|
186
189
|
// dimension of the recurrent state embeddings
|
|
187
|
-
uint32_t
|
|
190
|
+
uint32_t n_embd_s() const;
|
|
191
|
+
|
|
192
|
+
// whether or not the given layer is recurrent (for hybrid models)
|
|
193
|
+
bool is_recurrent(uint32_t il) const;
|
|
188
194
|
|
|
189
195
|
bool is_swa(uint32_t il) const;
|
|
190
196
|
};
|
|
@@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|
|
52
52
|
hparams.n_swa, hparams.swa_type);
|
|
53
53
|
}
|
|
54
54
|
|
|
55
|
-
void llama_kv_cache_unified_iswa::clear() {
|
|
56
|
-
kv_base->clear();
|
|
57
|
-
kv_swa ->clear();
|
|
55
|
+
void llama_kv_cache_unified_iswa::clear(bool data) {
|
|
56
|
+
kv_base->clear(data);
|
|
57
|
+
kv_swa ->clear(data);
|
|
58
58
|
}
|
|
59
59
|
|
|
60
60
|
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
|
@@ -95,54 +95,77 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
95
95
|
return kv_swa->seq_pos_max(seq_id);
|
|
96
96
|
}
|
|
97
97
|
|
|
98
|
-
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool
|
|
99
|
-
GGML_UNUSED(
|
|
98
|
+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
|
|
99
|
+
GGML_UNUSED(embd_all);
|
|
100
100
|
|
|
101
|
-
//
|
|
102
|
-
|
|
101
|
+
// first try simple split
|
|
102
|
+
do {
|
|
103
|
+
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
|
|
103
104
|
|
|
104
|
-
|
|
105
|
+
std::vector<llama_ubatch> ubatches;
|
|
105
106
|
|
|
106
|
-
|
|
107
|
+
while (sbatch.n_tokens > 0) {
|
|
108
|
+
auto ubatch = sbatch.split_simple(n_ubatch);
|
|
107
109
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
+
ubatches.push_back(ubatch);
|
|
111
|
+
}
|
|
110
112
|
|
|
111
|
-
ubatches
|
|
112
|
-
|
|
113
|
+
auto heads_base = kv_base->prepare(ubatches);
|
|
114
|
+
if (heads_base.empty()) {
|
|
115
|
+
break;
|
|
116
|
+
}
|
|
113
117
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
+
auto heads_swa = kv_swa->prepare(ubatches);
|
|
119
|
+
if (heads_swa.empty()) {
|
|
120
|
+
break;
|
|
121
|
+
}
|
|
118
122
|
|
|
119
|
-
|
|
120
|
-
if (heads_swa.empty()) {
|
|
121
|
-
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
122
|
-
}
|
|
123
|
+
assert(heads_base.size() == heads_swa.size());
|
|
123
124
|
|
|
124
|
-
|
|
125
|
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
|
126
|
+
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
|
127
|
+
} while (false);
|
|
125
128
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
+
// if it fails, try equal split
|
|
130
|
+
do {
|
|
131
|
+
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
|
129
132
|
|
|
130
|
-
|
|
131
|
-
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
|
132
|
-
}
|
|
133
|
+
std::vector<llama_ubatch> ubatches;
|
|
133
134
|
|
|
134
|
-
|
|
135
|
-
|
|
135
|
+
while (sbatch.n_tokens > 0) {
|
|
136
|
+
auto ubatch = sbatch.split_equal(n_ubatch);
|
|
136
137
|
|
|
137
|
-
|
|
138
|
-
|
|
138
|
+
ubatches.push_back(ubatch);
|
|
139
|
+
}
|
|
139
140
|
|
|
140
|
-
|
|
141
|
+
auto heads_base = kv_base->prepare(ubatches);
|
|
142
|
+
if (heads_base.empty()) {
|
|
143
|
+
break;
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
auto heads_swa = kv_swa->prepare(ubatches);
|
|
147
|
+
if (heads_swa.empty()) {
|
|
148
|
+
break;
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
assert(heads_base.size() == heads_swa.size());
|
|
152
|
+
|
|
153
|
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
|
154
|
+
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
|
155
|
+
} while (false);
|
|
156
|
+
|
|
157
|
+
// TODO: if we fail again, we should attempt different splitting strategies
|
|
158
|
+
// but to do that properly, we first have to refactor the batches to be more flexible
|
|
159
|
+
|
|
160
|
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
|
164
|
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
|
|
141
165
|
}
|
|
142
166
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
kv_swa ->defrag_sched(thold);
|
|
167
|
+
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
|
168
|
+
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
|
|
146
169
|
}
|
|
147
170
|
|
|
148
171
|
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
|
@@ -174,26 +197,34 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
|
|
174
197
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
|
175
198
|
|
|
176
199
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
200
|
+
llama_kv_cache_unified_iswa * kv) :
|
|
201
|
+
state_base(kv->get_base()->init_full()),
|
|
202
|
+
state_swa (kv->get_swa ()->init_full()),
|
|
203
|
+
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
|
207
|
+
llama_kv_cache_unified_iswa * kv,
|
|
208
|
+
llama_context * lctx,
|
|
209
|
+
bool optimize) :
|
|
210
|
+
state_base(kv->get_base()->init_update(lctx, optimize)),
|
|
211
|
+
state_swa (kv->get_swa ()->init_update(lctx, optimize)),
|
|
212
|
+
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
|
|
181
213
|
}
|
|
182
214
|
|
|
183
215
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
|
184
|
-
llama_memory_status status,
|
|
185
216
|
llama_kv_cache_unified_iswa * kv,
|
|
186
217
|
llama_sbatch sbatch,
|
|
187
218
|
std::vector<uint32_t> heads_base,
|
|
188
219
|
std::vector<uint32_t> heads_swa,
|
|
189
|
-
std::vector<llama_ubatch> ubatches)
|
|
190
|
-
: status(status),
|
|
220
|
+
std::vector<llama_ubatch> ubatches) :
|
|
191
221
|
sbatch(std::move(sbatch)),
|
|
192
|
-
ubatches(std::move(ubatches))
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
222
|
+
ubatches(std::move(ubatches)),
|
|
223
|
+
// note: here we copy the ubatches. not sure if this is ideal
|
|
224
|
+
state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
|
|
225
|
+
state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)),
|
|
226
|
+
status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
|
|
227
|
+
}
|
|
197
228
|
|
|
198
229
|
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
|
|
199
230
|
|
|
@@ -233,17 +264,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
|
|
|
233
264
|
|
|
234
265
|
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
|
|
235
266
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
267
|
+
|
|
236
268
|
return ubatches[i_next];
|
|
237
269
|
}
|
|
238
270
|
|
|
239
271
|
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
|
|
240
272
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
241
273
|
|
|
242
|
-
return state_base.get();
|
|
274
|
+
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
|
|
243
275
|
}
|
|
244
276
|
|
|
245
277
|
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
|
|
246
278
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
|
247
279
|
|
|
248
|
-
return state_swa.get();
|
|
280
|
+
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
|
|
249
281
|
}
|
|
@@ -11,7 +11,7 @@
|
|
|
11
11
|
// utilizes two instances of llama_kv_cache_unified
|
|
12
12
|
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
|
13
13
|
|
|
14
|
-
class llama_kv_cache_unified_iswa : public
|
|
14
|
+
class llama_kv_cache_unified_iswa : public llama_memory_i {
|
|
15
15
|
public:
|
|
16
16
|
llama_kv_cache_unified_iswa(
|
|
17
17
|
const llama_model & model,
|
|
@@ -31,7 +31,18 @@ public:
|
|
|
31
31
|
// llama_memory_i
|
|
32
32
|
//
|
|
33
33
|
|
|
34
|
-
|
|
34
|
+
llama_memory_state_ptr init_batch(
|
|
35
|
+
const llama_batch & batch,
|
|
36
|
+
uint32_t n_ubatch,
|
|
37
|
+
bool embd_all) override;
|
|
38
|
+
|
|
39
|
+
llama_memory_state_ptr init_full() override;
|
|
40
|
+
|
|
41
|
+
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
|
42
|
+
|
|
43
|
+
bool get_can_shift() const override;
|
|
44
|
+
|
|
45
|
+
void clear(bool data) override;
|
|
35
46
|
|
|
36
47
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
|
37
48
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
|
@@ -42,24 +53,6 @@ public:
|
|
|
42
53
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
|
43
54
|
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
|
44
55
|
|
|
45
|
-
//
|
|
46
|
-
// llama_kv_cache
|
|
47
|
-
//
|
|
48
|
-
|
|
49
|
-
llama_memory_state_ptr init_batch(
|
|
50
|
-
const llama_batch & batch,
|
|
51
|
-
uint32_t n_ubatch,
|
|
52
|
-
bool embd_pooled,
|
|
53
|
-
bool logits_all) override;
|
|
54
|
-
|
|
55
|
-
llama_memory_state_ptr init_full() override;
|
|
56
|
-
|
|
57
|
-
bool update(llama_context & lctx) override;
|
|
58
|
-
|
|
59
|
-
void defrag_sched(float thold) override;
|
|
60
|
-
|
|
61
|
-
bool get_can_shift() const override;
|
|
62
|
-
|
|
63
56
|
// state write/load
|
|
64
57
|
|
|
65
58
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
|
@@ -86,12 +79,16 @@ public:
|
|
|
86
79
|
|
|
87
80
|
// used to create a full-cache state
|
|
88
81
|
llama_kv_cache_unified_iswa_state(
|
|
89
|
-
llama_memory_status status,
|
|
90
82
|
llama_kv_cache_unified_iswa * kv);
|
|
91
83
|
|
|
84
|
+
// used to create an update state
|
|
85
|
+
llama_kv_cache_unified_iswa_state(
|
|
86
|
+
llama_kv_cache_unified_iswa * kv,
|
|
87
|
+
llama_context * lctx,
|
|
88
|
+
bool optimize);
|
|
89
|
+
|
|
92
90
|
// used to create a state from a batch
|
|
93
91
|
llama_kv_cache_unified_iswa_state(
|
|
94
|
-
llama_memory_status status,
|
|
95
92
|
llama_kv_cache_unified_iswa * kv,
|
|
96
93
|
llama_sbatch sbatch,
|
|
97
94
|
std::vector<uint32_t> heads_base,
|
|
@@ -120,8 +117,6 @@ public:
|
|
|
120
117
|
const llama_kv_cache_unified_state * get_swa() const;
|
|
121
118
|
|
|
122
119
|
private:
|
|
123
|
-
const llama_memory_status status;
|
|
124
|
-
|
|
125
120
|
//llama_kv_cache_unified_iswa * kv;
|
|
126
121
|
|
|
127
122
|
llama_sbatch sbatch;
|
|
@@ -131,6 +126,8 @@ private:
|
|
|
131
126
|
|
|
132
127
|
std::vector<llama_ubatch> ubatches;
|
|
133
128
|
|
|
134
|
-
|
|
135
|
-
|
|
129
|
+
const llama_memory_state_ptr state_base;
|
|
130
|
+
const llama_memory_state_ptr state_swa;
|
|
131
|
+
|
|
132
|
+
const llama_memory_status status;
|
|
136
133
|
};
|