@novastera-oss/llamarn 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +134 -36
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -2
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +30 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +50 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +134 -36
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
- package/cpp/llama.cpp/src/llama-batch.h +36 -11
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +313 -213
- package/cpp/llama.cpp/src/llama-context.h +16 -12
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
- package/cpp/llama.cpp/src/llama-graph.h +90 -34
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
- package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +64 -23
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +726 -141
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/llama.h +134 -36
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
#pragma once
|
|
2
2
|
|
|
3
3
|
#include "llama.h"
|
|
4
|
-
#include "llama-batch.h"
|
|
5
4
|
#include "llama-cparams.h"
|
|
6
5
|
#include "llama-graph.h"
|
|
7
6
|
#include "llama-adapter.h"
|
|
@@ -13,13 +12,13 @@
|
|
|
13
12
|
#include <vector>
|
|
14
13
|
|
|
15
14
|
struct llama_model;
|
|
16
|
-
|
|
15
|
+
class llama_batch_allocr;
|
|
17
16
|
|
|
18
17
|
class llama_io_read_i;
|
|
19
18
|
class llama_io_write_i;
|
|
20
19
|
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
struct llama_memory_i;
|
|
21
|
+
struct llama_memory_state_i;
|
|
23
22
|
|
|
24
23
|
struct llama_context {
|
|
25
24
|
// init scheduler and compute buffers, reserve worst-case graphs
|
|
@@ -47,12 +46,12 @@ struct llama_context {
|
|
|
47
46
|
uint32_t n_threads() const;
|
|
48
47
|
uint32_t n_threads_batch() const;
|
|
49
48
|
|
|
50
|
-
|
|
51
|
-
const llama_kv_cache * get_kv_self() const;
|
|
49
|
+
llama_memory_t get_memory() const;
|
|
52
50
|
|
|
53
51
|
// return true of the KV cache was updated
|
|
54
52
|
// TODO: remove
|
|
55
|
-
bool kv_self_update();
|
|
53
|
+
bool kv_self_update(bool optimize);
|
|
54
|
+
void kv_self_defrag_sched();
|
|
56
55
|
|
|
57
56
|
enum llama_pooling_type pooling_type() const;
|
|
58
57
|
|
|
@@ -103,8 +102,8 @@ struct llama_context {
|
|
|
103
102
|
llama_memory_state_i * mstate,
|
|
104
103
|
ggml_status & ret);
|
|
105
104
|
|
|
106
|
-
int encode(llama_batch &
|
|
107
|
-
int decode(llama_batch &
|
|
105
|
+
int encode(const llama_batch & batch_inp);
|
|
106
|
+
int decode(const llama_batch & batch_inp);
|
|
108
107
|
|
|
109
108
|
//
|
|
110
109
|
// state save/load
|
|
@@ -182,7 +181,7 @@ private:
|
|
|
182
181
|
|
|
183
182
|
// Make sure enough space is available for outputs.
|
|
184
183
|
// Returns max number of outputs for which space was reserved.
|
|
185
|
-
|
|
184
|
+
uint32_t output_reserve(int32_t n_outputs);
|
|
186
185
|
|
|
187
186
|
//
|
|
188
187
|
// graph
|
|
@@ -231,6 +230,9 @@ private:
|
|
|
231
230
|
|
|
232
231
|
std::unique_ptr<llama_memory_i> memory;
|
|
233
232
|
|
|
233
|
+
// TODO: temporary, until the llama_kv_self_defrag() API is removed
|
|
234
|
+
bool memory_force_optimize = false;
|
|
235
|
+
|
|
234
236
|
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
|
235
237
|
size_t logits_size = 0; // capacity (of floats) for logits
|
|
236
238
|
float * logits = nullptr;
|
|
@@ -244,8 +246,10 @@ private:
|
|
|
244
246
|
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
|
245
247
|
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
|
246
248
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
+
// reuse the batch_allocr to avoid unnecessary memory allocations
|
|
250
|
+
std::unique_ptr<llama_batch_allocr> batch_allocr;
|
|
251
|
+
|
|
252
|
+
uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
|
249
253
|
|
|
250
254
|
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
|
251
255
|
|
|
@@ -6,7 +6,8 @@
|
|
|
6
6
|
|
|
7
7
|
#include "llama-kv-cache-unified.h"
|
|
8
8
|
#include "llama-kv-cache-unified-iswa.h"
|
|
9
|
-
#include "llama-
|
|
9
|
+
#include "llama-memory-hybrid.h"
|
|
10
|
+
#include "llama-memory-recurrent.h"
|
|
10
11
|
|
|
11
12
|
#include <cassert>
|
|
12
13
|
#include <cmath>
|
|
@@ -139,6 +140,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|
|
139
140
|
|
|
140
141
|
std::vector<uint64_t> sum(n_tokens, 0);
|
|
141
142
|
|
|
143
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
142
144
|
for (int s = 0; s < n_seqs; ++s) {
|
|
143
145
|
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
|
144
146
|
|
|
@@ -156,6 +158,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|
|
156
158
|
}
|
|
157
159
|
}
|
|
158
160
|
|
|
161
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
159
162
|
for (int s = 0; s < n_seqs; ++s) {
|
|
160
163
|
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
|
161
164
|
|
|
@@ -180,6 +183,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
|
180
183
|
uint32_t * data = (uint32_t *) cls->data;
|
|
181
184
|
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
|
|
182
185
|
|
|
186
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
183
187
|
for (int s = 0; s < n_seqs; ++s) {
|
|
184
188
|
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
|
185
189
|
|
|
@@ -210,6 +214,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
|
210
214
|
std::vector<int> last_pos(n_tokens, -1);
|
|
211
215
|
std::vector<int> last_row(n_tokens, -1);
|
|
212
216
|
|
|
217
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
213
218
|
for (int s = 0; s < n_seqs; ++s) {
|
|
214
219
|
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
|
215
220
|
|
|
@@ -234,34 +239,18 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
|
234
239
|
}
|
|
235
240
|
}
|
|
236
241
|
|
|
237
|
-
void
|
|
242
|
+
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
|
238
243
|
GGML_UNUSED(ubatch);
|
|
239
244
|
|
|
240
|
-
const int64_t
|
|
245
|
+
const int64_t n_rs = mem_state->get_n_rs();
|
|
241
246
|
|
|
242
247
|
if (s_copy) {
|
|
243
248
|
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
|
244
249
|
int32_t * data = (int32_t *) s_copy->data;
|
|
245
250
|
|
|
246
251
|
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
247
|
-
for (uint32_t i = 0; i <
|
|
248
|
-
data[i] =
|
|
249
|
-
}
|
|
250
|
-
}
|
|
251
|
-
}
|
|
252
|
-
|
|
253
|
-
void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
|
|
254
|
-
GGML_UNUSED(ubatch);
|
|
255
|
-
|
|
256
|
-
const int64_t n_kv = kv_state->get_n_kv();
|
|
257
|
-
|
|
258
|
-
if (s_mask) {
|
|
259
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
|
|
260
|
-
float * data = (float *) s_mask->data;
|
|
261
|
-
|
|
262
|
-
// clear unused states
|
|
263
|
-
for (int i = 0; i < n_kv; ++i) {
|
|
264
|
-
data[i] = kv_state->s_mask(i);
|
|
252
|
+
for (uint32_t i = 0; i < n_rs; ++i) {
|
|
253
|
+
data[i] = mem_state->s_copy(i);
|
|
265
254
|
}
|
|
266
255
|
}
|
|
267
256
|
}
|
|
@@ -299,6 +288,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
|
299
288
|
const int32_t ti = s0*n_seq_tokens + i;
|
|
300
289
|
float f = -INFINITY;
|
|
301
290
|
|
|
291
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
302
292
|
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
|
303
293
|
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
|
|
304
294
|
if (hparams.use_alibi) {
|
|
@@ -338,6 +328,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
|
338
328
|
const int32_t ti = s0*n_seq_tokens + i;
|
|
339
329
|
float f = -INFINITY;
|
|
340
330
|
|
|
331
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
341
332
|
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
|
342
333
|
if (ubatch->seq_id[s0][s] == seq_id) {
|
|
343
334
|
if (hparams.use_alibi) {
|
|
@@ -393,6 +384,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
|
393
384
|
for (int j = 0; j < n_tokens; ++j) {
|
|
394
385
|
for (int i = 0; i < n_enc; ++i) {
|
|
395
386
|
float f = -INFINITY;
|
|
387
|
+
// TODO: fix indexing [UBATCH_IDX]
|
|
396
388
|
for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
|
|
397
389
|
const llama_seq_id seq_id = ubatch->seq_id[j][s];
|
|
398
390
|
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
|
|
@@ -412,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
|
412
404
|
}
|
|
413
405
|
}
|
|
414
406
|
|
|
407
|
+
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
|
408
|
+
if (self_kq_mask) {
|
|
409
|
+
mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
|
410
|
+
}
|
|
411
|
+
|
|
412
|
+
const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
|
|
413
|
+
|
|
414
|
+
if (s_copy) {
|
|
415
|
+
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
|
416
|
+
int32_t * data = (int32_t *) s_copy->data;
|
|
417
|
+
|
|
418
|
+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
|
419
|
+
for (uint32_t i = 0; i < n_rs; ++i) {
|
|
420
|
+
data[i] = mem_state->get_state_recr()->s_copy(i);
|
|
421
|
+
}
|
|
422
|
+
}
|
|
423
|
+
}
|
|
424
|
+
|
|
415
425
|
//
|
|
416
426
|
// llm_graph_context
|
|
417
427
|
//
|
|
@@ -650,6 +660,7 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
650
660
|
{
|
|
651
661
|
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
|
652
662
|
int64_t split_point = cur->ne[0] / 2;
|
|
663
|
+
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
|
653
664
|
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
|
654
665
|
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
|
655
666
|
|
|
@@ -659,6 +670,20 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
659
670
|
cur = ggml_mul(ctx0, x0, x1);
|
|
660
671
|
cb(cur, "ffn_mul", il);
|
|
661
672
|
} break;
|
|
673
|
+
case LLM_FFN_GEGLU:
|
|
674
|
+
{
|
|
675
|
+
// Split into two equal parts
|
|
676
|
+
int64_t split_point = cur->ne[0] / 2;
|
|
677
|
+
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
|
678
|
+
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
|
679
|
+
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
|
680
|
+
|
|
681
|
+
x0 = ggml_gelu(ctx0, x0);
|
|
682
|
+
cb(x0, "ffn_gelu", il);
|
|
683
|
+
|
|
684
|
+
cur = ggml_mul(ctx0, x0, x1);
|
|
685
|
+
cb(cur, "ffn_geglu", il);
|
|
686
|
+
} break;
|
|
662
687
|
}
|
|
663
688
|
|
|
664
689
|
if (gate && type_gate == LLM_FFN_PAR) {
|
|
@@ -769,9 +794,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
769
794
|
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
|
|
770
795
|
|
|
771
796
|
if (weight_before_ffn) {
|
|
772
|
-
//
|
|
773
|
-
ggml_tensor * repeated =
|
|
774
|
-
repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
|
|
797
|
+
// repeat cur to [n_embd, n_expert_used, n_tokens]
|
|
798
|
+
ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
|
|
775
799
|
cur = ggml_mul(ctx0, repeated, weights);
|
|
776
800
|
cb(cur, "ffn_moe_weighted", il);
|
|
777
801
|
}
|
|
@@ -956,40 +980,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
|
|
|
956
980
|
return cur;
|
|
957
981
|
}
|
|
958
982
|
|
|
959
|
-
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
|
960
|
-
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
961
|
-
|
|
962
|
-
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
|
|
963
|
-
|
|
964
|
-
const auto n_kv = kv_state->get_n_kv();
|
|
965
|
-
|
|
966
|
-
auto & cur = inp->s_copy;
|
|
967
|
-
|
|
968
|
-
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
|
969
|
-
ggml_set_input(cur);
|
|
970
|
-
|
|
971
|
-
res->add_input(std::move(inp));
|
|
972
|
-
|
|
973
|
-
return cur;
|
|
974
|
-
}
|
|
975
|
-
|
|
976
|
-
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
|
|
977
|
-
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
|
978
|
-
|
|
979
|
-
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
|
|
980
|
-
|
|
981
|
-
const auto n_kv = kv_state->get_n_kv();
|
|
982
|
-
|
|
983
|
-
auto & cur = inp->s_mask;
|
|
984
|
-
|
|
985
|
-
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
|
|
986
|
-
ggml_set_input(cur);
|
|
987
|
-
|
|
988
|
-
res->add_input(std::move(inp));
|
|
989
|
-
|
|
990
|
-
return cur;
|
|
991
|
-
}
|
|
992
|
-
|
|
993
983
|
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
|
994
984
|
auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
|
|
995
985
|
|
|
@@ -1059,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
|
|
|
1059
1049
|
return pos_bias;
|
|
1060
1050
|
}
|
|
1061
1051
|
|
|
1052
|
+
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|
1053
|
+
const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
|
|
1054
|
+
|
|
1055
|
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
|
|
1056
|
+
|
|
1057
|
+
{
|
|
1058
|
+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
|
1059
|
+
|
|
1060
|
+
const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
|
|
1061
|
+
|
|
1062
|
+
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
|
1063
|
+
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
1064
|
+
ggml_set_input(inp->self_kq_mask);
|
|
1065
|
+
|
|
1066
|
+
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
1067
|
+
}
|
|
1068
|
+
|
|
1069
|
+
{
|
|
1070
|
+
const auto n_rs = mem_state->get_state_recr()->get_n_rs();
|
|
1071
|
+
|
|
1072
|
+
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
|
1073
|
+
ggml_set_input(inp->s_copy);
|
|
1074
|
+
}
|
|
1075
|
+
|
|
1076
|
+
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
|
1077
|
+
}
|
|
1078
|
+
|
|
1062
1079
|
ggml_tensor * llm_graph_context::build_attn_mha(
|
|
1063
1080
|
ggml_cgraph * gf,
|
|
1064
1081
|
ggml_tensor * q,
|
|
@@ -1303,36 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1303
1320
|
return cur;
|
|
1304
1321
|
}
|
|
1305
1322
|
|
|
1306
|
-
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
|
1307
|
-
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
|
1308
|
-
|
|
1309
|
-
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
|
|
1310
|
-
|
|
1311
|
-
{
|
|
1312
|
-
const auto n_kv = kv_state->get_base()->get_n_kv();
|
|
1313
|
-
|
|
1314
|
-
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
|
1315
|
-
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
1316
|
-
ggml_set_input(inp->self_kq_mask);
|
|
1317
|
-
|
|
1318
|
-
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
1319
|
-
}
|
|
1320
|
-
|
|
1321
|
-
{
|
|
1322
|
-
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
|
1323
|
-
|
|
1324
|
-
const auto n_kv = kv_state->get_swa()->get_n_kv();
|
|
1325
|
-
|
|
1326
|
-
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
|
1327
|
-
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
|
1328
|
-
ggml_set_input(inp->self_kq_mask_swa);
|
|
1329
|
-
|
|
1330
|
-
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
|
1331
|
-
}
|
|
1332
|
-
|
|
1333
|
-
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
|
|
1334
|
-
}
|
|
1335
|
-
|
|
1336
1323
|
ggml_tensor * llm_graph_context::build_attn(
|
|
1337
1324
|
llm_graph_input_attn_kv_unified_iswa * inp,
|
|
1338
1325
|
ggml_cgraph * gf,
|
|
@@ -1442,56 +1429,182 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1442
1429
|
return cur;
|
|
1443
1430
|
}
|
|
1444
1431
|
|
|
1445
|
-
ggml_tensor * llm_graph_context::
|
|
1446
|
-
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
|
|
1454
|
-
|
|
1455
|
-
|
|
1432
|
+
ggml_tensor * llm_graph_context::build_attn(
|
|
1433
|
+
llm_graph_input_mem_hybrid * inp,
|
|
1434
|
+
ggml_cgraph * gf,
|
|
1435
|
+
ggml_tensor * wo,
|
|
1436
|
+
ggml_tensor * wo_b,
|
|
1437
|
+
ggml_tensor * q_cur,
|
|
1438
|
+
ggml_tensor * k_cur,
|
|
1439
|
+
ggml_tensor * v_cur,
|
|
1440
|
+
ggml_tensor * kq_b,
|
|
1441
|
+
ggml_tensor * v_mla,
|
|
1442
|
+
float kq_scale,
|
|
1443
|
+
int il) const {
|
|
1444
|
+
// these nodes are added to the graph together so that they are not reordered
|
|
1445
|
+
// by doing so, the number of splits in the graph is reduced
|
|
1446
|
+
ggml_build_forward_expand(gf, q_cur);
|
|
1447
|
+
ggml_build_forward_expand(gf, k_cur);
|
|
1448
|
+
ggml_build_forward_expand(gf, v_cur);
|
|
1449
|
+
|
|
1450
|
+
const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
|
|
1451
|
+
|
|
1452
|
+
// store to KV cache
|
|
1453
|
+
{
|
|
1454
|
+
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
|
1455
|
+
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
|
1456
|
+
}
|
|
1457
|
+
|
|
1458
|
+
const auto & kq_mask = inp->get_kq_mask();
|
|
1459
|
+
|
|
1460
|
+
ggml_tensor * q = q_cur;
|
|
1461
|
+
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
|
1462
|
+
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
|
1463
|
+
|
|
1464
|
+
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1465
|
+
cb(cur, "kqv_out", il);
|
|
1466
|
+
|
|
1467
|
+
if (wo) {
|
|
1468
|
+
cur = build_lora_mm(wo, cur);
|
|
1469
|
+
if (arch == LLM_ARCH_GLM4) {
|
|
1470
|
+
// GLM4 seems to have numerical issues with half-precision accumulators
|
|
1471
|
+
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
|
1472
|
+
}
|
|
1473
|
+
}
|
|
1474
|
+
|
|
1475
|
+
if (wo_b) {
|
|
1476
|
+
cur = ggml_add(ctx0, cur, wo_b);
|
|
1477
|
+
}
|
|
1478
|
+
|
|
1479
|
+
return cur;
|
|
1480
|
+
}
|
|
1456
1481
|
|
|
1457
|
-
|
|
1482
|
+
// Creates the attention input for the unified iSWA KV cache state: one KQ mask sized by the
// base cache and one sized by the SWA cache, then hands the input over to `res`.
// Returns the raw pointer to the registered input.
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
    const auto * iswa_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);

    auto attn_inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, iswa_state);

    // allocates an F32 mask of shape {n_kv, n_tokens padded to GGML_KQ_MASK_PAD} and flags it as a graph input
    const auto new_mask = [&](int64_t n_kv) {
        ggml_tensor * mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
        ggml_set_input(mask);
        return mask;
    };

    // when flash attention is enabled the mask is cast to F16, otherwise the F32 tensor is used as-is
    const auto to_attn_mask = [&](ggml_tensor * mask) {
        return cparams.flash_attn ? ggml_cast(ctx0, mask, GGML_TYPE_F16) : mask;
    };

    // mask sized by the base cache
    attn_inp->self_kq_mask     = new_mask(iswa_state->get_base()->get_n_kv());
    attn_inp->self_kq_mask_cnv = to_attn_mask(attn_inp->self_kq_mask);

    GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");

    // mask sized by the SWA cache
    attn_inp->self_kq_mask_swa     = new_mask(iswa_state->get_swa()->get_n_kv());
    attn_inp->self_kq_mask_swa_cnv = to_attn_mask(attn_inp->self_kq_mask_swa);

    return static_cast<llm_graph_input_attn_kv_unified_iswa *>(res->add_input(std::move(attn_inp)));
}
|
|
1463
1511
|
|
|
1464
|
-
|
|
1465
|
-
|
|
1466
|
-
|
|
1512
|
+
// Adds the graph nodes that prepare the recurrent-state tensor `s`:
//   - scales the row `rs_zero` by 0 (the view is zero-sized when rs_zero < 0)
//   - unless `avoid_copies` is set, gathers the first `n_seqs` rows listed in `state_copy`
//   - copies the remaining `n_kv - n_seqs` rows listed in `state_copy` back into `s`,
//     at element offset (kv_head + n_seqs)*state_size
// Returns the gathered rows, or the {state_size, kv_size} reshape of `s` when `avoid_copies` is set.
ggml_tensor * llm_graph_context::build_rs(
        ggml_cgraph * gf,
        ggml_tensor * s,
        ggml_tensor * state_copy,
            int32_t   state_size,
            int32_t   n_seqs,
           uint32_t   n_kv,
           uint32_t   kv_head,
           uint32_t   kv_size,
            int32_t   rs_zero,
               bool   avoid_copies) const {
    // {state_size, kv_size} view over the whole state buffer
    ggml_tensor * all_states = ggml_reshape_2d(ctx0, s, state_size, kv_size);

    // Clear a single state which will then be copied to the other cleared states.
    // With a negative rs_zero both the length and the offset collapse to 0, making this a no-op.
    const bool has_zero_row = rs_zero >= 0;
    ggml_tensor * zero_row = ggml_view_1d(ctx0, all_states,
            has_zero_row ? state_size : 0,
            has_zero_row ? rs_zero*all_states->nb[1] : 0);
    ggml_tensor * cleared = ggml_scale_inplace(ctx0, zero_row, 0);
    ggml_build_forward_expand(gf, cleared);

    ggml_tensor * result = nullptr;

    if (avoid_copies) {
        // FIXME: make the gathering operation happen before the copy below
        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
        result = all_states;
    } else {
        // copy states
        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
        // {state_size, kv_size} -> {state_size, n_seqs}
        ggml_tensor * seq_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0);
        result = ggml_get_rows(ctx0, all_states, seq_ids);
        ggml_build_forward_expand(gf, result);
    }

    // copy extra states which won't be changed further (between n_seqs and n_kv)
    const auto n_extra = n_kv - n_seqs;
    ggml_tensor * extra_ids  = ggml_view_1d(ctx0, state_copy, n_extra, n_seqs*state_copy->nb[0]);
    ggml_tensor * extra_rows = ggml_get_rows(ctx0, all_states, extra_ids);
    ggml_tensor * extra_dst  = ggml_view_1d(ctx0, s, state_size*n_extra, (kv_head + n_seqs)*state_size*ggml_element_size(s));
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, extra_rows, extra_dst));

    return result;
}
|
|
1554
|
+
|
|
1555
|
+
// Creates the recurrent-state graph input: an I32 tensor `s_copy` with n_rs entries,
// flagged as a graph input, and hands the input object over to `res`.
// Returns the raw pointer to the registered input.
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
    const auto * rs_state = static_cast<const llama_memory_recurrent_state *>(mstate);

    auto rs_inp = std::make_unique<llm_graph_input_rs>(rs_state);

    ggml_tensor * copy_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, rs_state->get_n_rs());
    ggml_set_input(copy_ids);
    rs_inp->s_copy = copy_ids;

    return static_cast<llm_graph_input_rs *>(res->add_input(std::move(rs_inp)));
}
|
|
1567
|
+
|
|
1568
|
+
// Convenience overload for recurrent memory: reads n_rs, head, size and rs_z from the
// recurrent state behind `mstate` and forwards them, together with inp->s_copy,
// to the full-parameter build_rs().
ggml_tensor * llm_graph_context::build_rs(
        llm_graph_input_rs * inp,
        ggml_cgraph * gf,
        ggml_tensor * s,
            int32_t   state_size,
            int32_t   n_seqs,
               bool   avoid_copies) const {
    const auto * rs_state = static_cast<const llama_memory_recurrent_state *>(mstate);

    const auto n_rs    = rs_state->get_n_rs();
    const auto rs_head = rs_state->get_head();
    const auto rs_size = rs_state->get_size();
    const auto rs_z    = rs_state->get_rs_z();

    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, n_rs, rs_head, rs_size, rs_z, avoid_copies);
}
|
|
1579
|
+
|
|
1580
|
+
// Convenience overload for hybrid memory: takes the recurrent sub-state of the hybrid
// state behind `mstate`, reads n_rs, head, size and rs_z from it and forwards them,
// together with inp->s_copy, to the full-parameter build_rs().
ggml_tensor * llm_graph_context::build_rs(
        llm_graph_input_mem_hybrid * inp,
        ggml_cgraph * gf,
        ggml_tensor * s,
            int32_t   state_size,
            int32_t   n_seqs,
               bool   avoid_copies) const {
    const auto * hybrid_state = static_cast<const llama_memory_hybrid_state *>(mstate);
    const auto * rs_state     = hybrid_state->get_state_recr();

    const auto n_rs    = rs_state->get_n_rs();
    const auto rs_head = rs_state->get_head();
    const auto rs_size = rs_state->get_size();
    const auto rs_z    = rs_state->get_rs_z();

    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, n_rs, rs_head, rs_size, rs_z, avoid_copies);
}
|
|
1477
1591
|
|
|
1478
1592
|
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
const llama_ubatch & ubatch,
|
|
1593
|
+
llm_graph_input_rs * inp,
|
|
1594
|
+
ggml_cgraph * gf,
|
|
1595
|
+
const llama_ubatch & ubatch,
|
|
1483
1596
|
int il) const {
|
|
1484
|
-
const auto * kv_state = static_cast<const
|
|
1597
|
+
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
|
1485
1598
|
|
|
1486
1599
|
const auto token_shift_count = hparams.token_shift_count;
|
|
1487
1600
|
|
|
1488
1601
|
const int64_t n_seqs = ubatch.n_seqs;
|
|
1489
1602
|
|
|
1490
|
-
ggml_tensor * token_shift_all = kv_state->
|
|
1603
|
+
ggml_tensor * token_shift_all = kv_state->get_r_l(il);
|
|
1491
1604
|
|
|
1492
|
-
ggml_tensor * token_shift =
|
|
1493
|
-
gf, token_shift_all,
|
|
1494
|
-
hparams.
|
|
1605
|
+
ggml_tensor * token_shift = build_rs(
|
|
1606
|
+
inp, gf, token_shift_all,
|
|
1607
|
+
hparams.n_embd_r(), n_seqs);
|
|
1495
1608
|
|
|
1496
1609
|
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
|
1497
1610
|
|
|
@@ -1502,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
|
1502
1615
|
ggml_tensor * token_shift,
|
|
1503
1616
|
const llama_ubatch & ubatch,
|
|
1504
1617
|
int il) const {
|
|
1505
|
-
const auto * kv_state = static_cast<const
|
|
1618
|
+
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
|
1506
1619
|
|
|
1507
1620
|
const auto token_shift_count = hparams.token_shift_count;
|
|
1508
1621
|
const auto n_embd = hparams.n_embd;
|
|
@@ -1514,7 +1627,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
|
1514
1627
|
return ggml_cpy(
|
|
1515
1628
|
ctx0,
|
|
1516
1629
|
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
|
1517
|
-
ggml_view_1d(ctx0, kv_state->
|
|
1630
|
+
ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
|
|
1518
1631
|
);
|
|
1519
1632
|
}
|
|
1520
1633
|
|
|
@@ -1565,23 +1678,30 @@ void llm_graph_context::build_pooling(
|
|
|
1565
1678
|
ggml_tensor * inp_cls = build_inp_cls();
|
|
1566
1679
|
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
|
1567
1680
|
|
|
1568
|
-
if (cls
|
|
1681
|
+
if (cls) {
|
|
1569
1682
|
// classification head
|
|
1570
1683
|
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
|
1571
|
-
cur =
|
|
1684
|
+
cur = ggml_mul_mat(ctx0, cls, inp);
|
|
1685
|
+
if (cls_b) {
|
|
1686
|
+
cur = ggml_add(ctx0, cur, cls_b);
|
|
1687
|
+
}
|
|
1572
1688
|
cur = ggml_tanh(ctx0, cur);
|
|
1573
1689
|
|
|
1574
1690
|
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
|
1575
1691
|
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
|
1576
1692
|
if (cls_out) {
|
|
1577
|
-
|
|
1578
|
-
|
|
1693
|
+
cur = ggml_mul_mat(ctx0, cls_out, cur);
|
|
1694
|
+
if (cls_out_b) {
|
|
1695
|
+
cur = ggml_add(ctx0, cur, cls_out_b);
|
|
1696
|
+
}
|
|
1579
1697
|
}
|
|
1580
1698
|
} else if (cls_out) {
|
|
1581
1699
|
// Single layer classification head (direct projection)
|
|
1582
1700
|
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
|
1583
|
-
|
|
1584
|
-
|
|
1701
|
+
cur = ggml_mul_mat(ctx0, cls_out, inp);
|
|
1702
|
+
if (cls_out_b) {
|
|
1703
|
+
cur = ggml_add(ctx0, cur, cls_out_b);
|
|
1704
|
+
}
|
|
1585
1705
|
} else {
|
|
1586
1706
|
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
|
|
1587
1707
|
}
|