whispercpp 1.3.2 → 1.3.3
This diff represents the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/examples/talk-llama/llama-batch.cpp
@@ -1,313 +1,782 @@
 #include "llama-batch.h"
 
+#include "llama-impl.h"
+#include "llama-vocab.h"
+#include "llama-memory.h"
+
 #include <cassert>
 #include <cstring>
 #include <algorithm>
+#include <sstream>
 
-            break;
-        }
+llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
+    const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
+    debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+
+    seq_pos.resize(LLAMA_MAX_SEQ);
+    seq_cpl.resize(LLAMA_MAX_SEQ);
+    for (auto & cur : seq_cpl) {
+        cur.resize(LLAMA_MAX_SEQ);
     }
-    ubatch_pos.resize(n_ubatch);
-    ubatch_n_seq_id.resize(n_ubatch);
-    ubatch_seq_id.resize(n_ubatch);
-    ubatch_output.resize(n_ubatch);
-    llama_ubatch ubatch = {
-        /*equal_seqs   =*/ true,
-        /*n_tokens     =*/ 0,
-        /*n_seq_tokens =*/ 0,
-        /*n_seqs       =*/ 0,
-        /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
-        /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
-        /*pos          =*/ ubatch_pos.data(),
-        /*n_seq_id     =*/ ubatch_n_seq_id.data(),
-        /*seq_id       =*/ ubatch_seq_id.data(),
-        /*output       =*/ ubatch_output.data(),
-    };
-    return ubatch;
+
+    seq_idx.resize(LLAMA_MAX_SEQ, -1);
 }
 
+bool llama_batch_allocr::init(
+        const llama_batch & batch_inp,
+        const llama_vocab & vocab,
+        const llama_memory_i * memory,
+        uint32_t n_embd,
+        bool output_all) {
+    clear();
+
+    batch = batch_inp;
+
+    this->vocab = &vocab;
+
+    GGML_ASSERT(batch.n_tokens > 0);
+
+    //
+    // validate input batch
+    //
+
+    if (batch.token) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return false;
             }
-        } else {
-            // simple split
-            ubatch.token = batch->token + seq.offset;
         }
-    }
-                n_embd * sizeof(float)
-            );
+    }
+
+    if (batch.seq_id) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                    return false;
+                }
             }
-        } else {
-            // simple split
-            ubatch.embd = batch->embd + (n_embd * seq.offset);
         }
-    } else {
-        ubatch.embd = nullptr;
     }
+
+    //
+    // auto-generate missing fields
+    //
+
+    if (!batch.n_seq_id) {
+        n_seq_id.resize(batch.n_tokens);
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            n_seq_id[i] = seq_id_0.size();
         }
-        // simple split
-        ubatch.pos = batch->pos + seq.offset;
+        batch.n_seq_id = n_seq_id.data();
     }
+
+    if (!batch.seq_id) {
+        seq_id.resize(batch.n_tokens + 1);
+        seq_id[batch.n_tokens] = NULL;
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            seq_id[i] = seq_id_0.data();
         }
+        batch.seq_id = seq_id.data();
+    }
+
+    if (!batch.pos) {
+        pos.resize(batch.n_tokens);
+
+        // initialize the starting position for each sequence based on the positions in the memory
+        llama_pos p0[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            if (!memory) {
+                // if no memory -> start from 0
+                p0[s] = 0;
+            } else {
+                p0[s] = memory->seq_pos_max(s) + 1;
             }
         }
+
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            pos[i] = p0[seq_id];
+
+            // update the starting position for all sequences that are assigned to the this token
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
+                p0[seq_id] = pos[i] + 1;
+            }
        }
+
+        batch.pos = pos.data();
     }
+
+    if (!batch.logits) {
+        if (output_all) {
+            // return the output for all tokens
+            output.resize(batch.n_tokens, true);
+        } else {
+            // return the output only for the last token
+            output.resize(batch.n_tokens, false);
+            output[output.size() - 1] = true;
        }
+
+        batch.logits = output.data();
+    } else if (output_all) {
+        bool warn = false;
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                warn = true;
            }
-        }
+        }
+
+        if (warn) {
+            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
+
+            output.resize(batch.n_tokens, true);
+            batch.logits = output.data();
+        }
+    }
+
+    //
+    // compute stats
+    //
+
+    this->n_embd = n_embd;
+
+    // count the outputs in this batch
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        n_outputs += batch.logits[i] != 0;
+    }
+
+    // determine coupled sequences
+    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        const llama_seq_id s0 = batch.seq_id[i][0];
+
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            const llama_seq_id s1 = batch.seq_id[i][s];
+
+            seq_pos[s1].insert(batch.pos[i]);
+
+            if (s > 0) {
+                // mark that sequence s1 is coupled to s0
+                seq_cpl[s1][s0] = true;
+
+                // note: tracking the other way around is not necessary for now
+                //seq_cpl[s0][s1] = true;
             }
         }
-    }
-        // only get last output
-        for (size_t i = 0; i < length; ++i) {
-            size_t id = ids[seq.offset + i];
-            int8_t is_last = id == ids.size() - 1;
-            ubatch.output[ubatch.n_tokens + i] = is_last;
-            if (is_last) { out_ids.push_back(id); }
-        }
-    }
-    if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
-        ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
-    }
-    ubatch.n_tokens += length;
-    ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
-    seq.offset += length;
-    seq.length -= length;
-    n_tokens -= length;
-    GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
-    }
+    }
 
-        add_seq_to_ubatch(ubatch, s, length);
-    }
-    return ubatch;
-}
+    // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
+    {
+        seq_set_t seq_set_unq;
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            seq_set_t cur;
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
 
-        if (length == 0) {
-            length = s.length < n_ubatch ? s.length : n_ubatch;
+                cur        .set(seq_id);
+                seq_set_unq.set(seq_id);
+            }
+
+            seq_set.push_back(cur);
+            seq_set_map[cur].push_back(i);
+        }
+
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            if (seq_set_unq.test(s)) {
+                seq_idx[s] = seq_id_unq.size();
+                seq_id_unq.push_back(s);
             }
-        add_seq_to_ubatch(ubatch, s, length);
-        n_tokens_in_ubatch += length;
-        // shared prompts can't be mixed with any of their sequences,
-        // so it's safer to compute them in their own ubatch
-        if (s.n_seq_id > 1) { break; }
-        // stop when there isn't enough space for another sequence
-        if (length + n_tokens_in_ubatch > n_ubatch) { break; }
         }
     }
-    return ubatch;
-}
 
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
+
+        llama_ubatch ubatch {
+            /*.equal_seqs   =*/ false,
+            /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
+            /*.n_seq_tokens =*/ (uint32_t) 1,
+            /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
+            /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
+            /*.token        =*/ batch.token,
+            /*.embd         =*/ batch.embd,
+            /*.pos          =*/ batch.pos,
+            /*.n_seq_id     =*/ batch.n_seq_id,
+            /*.seq_id       =*/ batch.seq_id,
+            /*.seq_id_unq   =*/ this->seq_id_unq.data(),
+            /*.seq_idx      =*/ this->seq_idx.data(),
+            /*.output       =*/ batch.logits,
+        };
+
+        ubatch_print(ubatch, debug);
+
+        LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
+        for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+            if (seq_pos[s0].empty()) {
+                continue;
+            }
+
+            std::stringstream ss;
+            for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    ss << s1 << " ";
+                }
+            }
+
+            LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
+                    __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
+        }
+        LLAMA_LOG_DEBUG("%s: ]\n", __func__);
     }
-    return ubatch;
-}
 
-    this->n_embd = n_embd;
-    this->logits_all = logits_all;
+    //
+    // consistency checks
+    //
 
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
+        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+        if (p0 >= 0) {
+            bool ok = true;
+
+            if (batch.token) {
+                if (seq_pos_min(s) != p0 + 1) {
+                    ok = false;
+                }
+            } else {
+                assert(batch.embd);
+
+                // for embeddings (typically used as vision input), we allow them to have repeating positions
+                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
+                    ok = false;
+                }
+            }
+
+            if (!ok) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
+        }
+
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+            return false;
+        }
+    }
+
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
+                        return false;
                    }
-                    // no pos, sort by id
-                    return a < b;
                }
-                // shared prompts go first
-                return n_seq_a > n_seq_b;
            }
+        }
+    }
+
+    // disallow partial sequence sub-sets:
+    //
+    //                               invalid:          x
+    //            i: 0 1 2 ...
+    // ---------------------------------------
+    //   seq_id[i][0]: 0 0 1
+    //   seq_id[i][1]: 1 1 2
+    //   seq_id[i][2]: 2
+    //
+    // disallow decreasing sequence positions:
+    //
+    //                                invalid:                  x
+    //            i: 0 1 2 3 4 5 6 ...
+    // ---------------------------------------
+    //         pos[i]: 4 5 0 1 6 2 3
+    //   seq_id[i][0]: 0 0 1 1 0 1 0
+    //
+    {
+        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            cur_seq_set[s].set();
+        }
+
+        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            cur_seq_pos[s] = -1;
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            const llama_pos pos = batch.pos[i];
+
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+
+                cur_seq_set[seq_id] &= seq_set[i];
+
+                if (cur_seq_set[seq_id].none()) {
+                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
+                    return false;
+                }
+
+                if (pos < cur_seq_pos[seq_id]) {
+                    LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
+                    return false;
                }
            }
+        }
+    }
+
+    split_reset();
+
+    return true;
+}
+
+llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
+    const uint32_t n_tokens = n_seq_tokens*n_seqs;
+
+    clear();
+    split_reset();
+
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();
+
+    ubatch.token     .resize(n_tokens);
+    ubatch.embd      .clear();
+    ubatch.pos       .resize(n_tokens);
+    ubatch.n_seq_id  .resize(n_tokens);
+    ubatch.seq_id    .resize(n_tokens);
+    ubatch.seq_id_unq.resize(0);
+    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    ubatch.output    .resize(n_tokens);
+
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        ubatch.seq_idx[s] = s;
+        ubatch.seq_id_unq.push_back(s);
+    }
+
+    llama_ubatch res {
+        /*.equal_seqs   =*/ true,
+        /*.n_tokens     =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_seq_tokens,
+        /*.n_seqs       =*/ n_seqs,
+        /*.n_seqs_unq   =*/ n_seqs,
+
+        /*.token        =*/ ubatch.token.data(),
+        /*.embd         =*/ nullptr,
+        /*.pos          =*/ ubatch.pos.data(),
+        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
+        /*.seq_id       =*/ ubatch.seq_id.data(),
+        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
+        /*.seq_idx      =*/ ubatch.seq_idx.data(),
+        /*.output       =*/ ubatch.output.data(),
+    };
+
+    return res;
+}
+
+const llama_batch & llama_batch_allocr::get_batch() const {
+    return batch;
+}
+
+uint32_t llama_batch_allocr::get_n_tokens() const {
+    return batch.n_tokens;
+}
+
+uint32_t llama_batch_allocr::get_n_outputs() const {
+    return n_outputs;
+}
+
+std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
+    return out_ids;
+}
+
+llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
+}
+
+llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
+}
+
+void llama_batch_allocr::split_reset() {
+    out_ids.clear();
+
+    used.clear();
+    used.resize(get_n_tokens(), false);
+
+    ubatches.clear();
+}
+
+llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
+    // find the first unused token
+    uint32_t cur_idx = 0;
+    while (cur_idx < used.size() && used[cur_idx]) {
+        ++cur_idx;
+    }
+
+    // we are done
+    if (cur_idx >= used.size()) {
+        return {};
+    }
+
+    std::vector<int32_t> idxs;
+
+    while (true) {
+        idxs.push_back(cur_idx);
+
+        used[cur_idx] = true;
+
+        ++cur_idx;
+
+        if (cur_idx >= used.size()) {
+            break;
+        }
+
+        if (idxs.size() >= n_ubatch) {
+            break;
+        }
+    }
+
+    return ubatch_add(idxs, idxs.size(), false);
+}
+
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+    std::vector<seq_set_t> cur_seq_set;
+
+    // determine the non-overlapping sequence sets participating in this ubatch
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        if (used[i]) {
+            continue;
+        }
+
+        bool add = true;
+
+        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
+            // no overlap with existing sequence sets:
+            if (!(cur_seq_set[s] & seq_set[i]).none()) {
+                add = false;
+                break;
+            }
+        }
+
+        if (add) {
+            cur_seq_set.push_back(seq_set[i]);
+
+            if (cur_seq_set.size() > n_ubatch) {
+                break;
            }
        }
-            llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
-            seq.push_back(new_seq);
-            last_seq = &seq.back();
    }
 
+    const uint32_t n_seqs = cur_seq_set.size();
+
+    // we are done
+    if (n_seqs == 0) {
+        return {};
+    }
+
+    // the current batch index of each sequence set
+    std::vector<int32_t> cur_idx(n_seqs, 0);
+
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
+            ++cur_idx[s];
+        }
+    }
+
+    // the list of batch indices for each sequence set
+    // at the end we will concat these to get the final ubatch
+    std::vector<idx_vec_t> idxs_per_seq(n_seqs);
+
+    while (true) {
+        // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
+        // if we haven't reached n_ubatch
+        bool can_expand = true;
+
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
+                can_expand = false;
+                break;
            }
+        }
+
+        if (!can_expand) {
+            break;
+        }
+
+        for (uint32_t s = 0; s < n_seqs; ++s) {
+            const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
+
+            idxs_per_seq[s].push_back(idx);
+
+            used[idx] = true;
+
+            ++cur_idx[s];
+        }
+
+        if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
+            break;
+        }
+    }
+
+    // concat the per-sequence-set lists
+    std::vector<int32_t> idxs;
+
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
+    }
+
+    return ubatch_add(idxs, n_seqs, true);
 }
 
-llama_batch_allocr::
+llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
+    // find the first unused token
+    uint32_t cur_idx = 0;
+    while (cur_idx < used.size() && used[cur_idx]) {
+        ++cur_idx;
+    }
+
+    // we are done
+    if (cur_idx >= used.size()) {
+        return {};
+    }
+
+    // this is the starting sequence set
+    // we allow adding tokens only if their sequence set is a subset of the current sequence set
+    auto cur_seq_set = seq_set[cur_idx];
+
+    std::vector<int32_t> idxs;
+
+    while (true) {
+        idxs.push_back(cur_idx);
+
+        used[cur_idx] = true;
+
+        if (idxs.size() >= n_ubatch) {
+            break;
        }
+
+        do {
+            ++cur_idx;
+        } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
+
+        if (cur_idx == get_n_tokens()) {
+            break;
+        }
+
+        cur_seq_set = seq_set[cur_idx];
    }
+
+    return ubatch_add(idxs, 1, true);
+}
+
+void llama_batch_allocr::clear() {
+    n_outputs = 0;
+
+    batch = {};
+
+    pos       .clear();
+    n_seq_id  .clear();
+    seq_id    .clear();
+    seq_id_unq.clear();
+    output    .clear();
+
+    for (auto & cur : seq_pos) {
+        cur.clear();
+    }
+
+    for (auto & cur : seq_cpl) {
+        std::fill(cur.begin(), cur.end(), false);
+    }
+
+    seq_set.clear();
+
+    seq_set_map.clear();
+
+    std::fill(seq_idx.begin(), seq_idx.end(), -1);
+}
+
+llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
+    const uint32_t n_tokens = idxs.size();
+
+    assert(n_tokens%n_seqs == 0);
+
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();
+
+    const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
+
+    const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
+    const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
+
+    ubatch.token     .resize(n_tokens);
+    ubatch.embd      .resize(n_embd_all);
+    ubatch.pos       .resize(n_pos_all);
+    ubatch.n_seq_id  .resize(n_tokens);
+    ubatch.seq_id    .resize(n_tokens);
+    ubatch.seq_id_unq.resize(0);
+    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    ubatch.output    .resize(n_tokens);
+
+    seq_set_t seq_set_unq;
+
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        if (batch.token) {
+            ubatch.token[i] = batch.token[idxs[i]];
+        }
+
+        if (batch.embd) {
+            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+        }
+
+        for (int j = 0; j < n_pos_cur; ++j) {
+            ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+        }
+
+        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
+        ubatch.output[i]   = batch.logits[idxs[i]];
+
+        for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+            seq_set_unq.set(ubatch.seq_id[i][s]);
+        }
+
+        if (ubatch.output[i]) {
+            out_ids.push_back(idxs[i]);
        }
-        batch.n_seq_id = n_seq_id.data();
    }
+
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        if (seq_set_unq.test(s)) {
+            ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
+            ubatch.seq_id_unq.push_back(s);
        }
-        batch.seq_id = seq_id.data();
    }
+
+    llama_ubatch res {
+        /*.equal_seqs   =*/ equal_seqs,
+        /*.n_tokens     =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_tokens/n_seqs,
+        /*.n_seqs       =*/ n_seqs,
+        /*.n_seqs_unq   =*/ (uint32_t) ubatch.seq_id_unq.size(),
+
+        /*.token        =*/ batch.token ? ubatch.token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? ubatch.embd.data() : nullptr,
+        /*.pos          =*/ ubatch.pos.data(),
+        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
+        /*.seq_id       =*/ ubatch.seq_id.data(),
+        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
+        /*.seq_idx      =*/ ubatch.seq_idx.data(),
+        /*.output       =*/ ubatch.output.data(),
+    };
+
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
+
+        ubatch_print(res, debug);
+    }
+
+    return res;
+}
+
+void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: equal_seqs   = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: n_tokens     = %d\n", __func__, ubatch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
+        LLAMA_LOG_DEBUG("%s: n_seqs       = %d\n", __func__, ubatch.n_seqs);
+        LLAMA_LOG_DEBUG("%s: n_seqs_unq   = %d\n", __func__, ubatch.n_seqs_unq);
+
+        std::stringstream ss_seq_id_unq;
+        std::stringstream ss_seq_idx;
+
+        ss_seq_id_unq << "[ ";
+        ss_seq_idx << "[";
+
+        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+            ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
+        }
+
+        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            if (ubatch.seq_idx[s] >= 0) {
+                ss_seq_idx << ubatch.seq_idx[s]%10;
+            } else {
+                ss_seq_idx << ".";
+            }
+        }
+
+        ss_seq_id_unq << "]";
+        ss_seq_idx    << "]";
+
+        LLAMA_LOG_DEBUG("%s: token      = %p\n", __func__, (void *) ubatch.token);
+        LLAMA_LOG_DEBUG("%s: embd       = %p\n", __func__, (void *) ubatch.embd);
+        LLAMA_LOG_DEBUG("%s: pos        = %p\n", __func__, (void *) ubatch.pos);
+        LLAMA_LOG_DEBUG("%s: n_seq_id   = %p\n", __func__, (void *) ubatch.n_seq_id);
+        LLAMA_LOG_DEBUG("%s: seq_id     = %p\n", __func__, (void *) ubatch.seq_id);
+        LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
+        LLAMA_LOG_DEBUG("%s: seq_idx    = %s\n", __func__, ss_seq_idx.str().c_str());
+        LLAMA_LOG_DEBUG("%s: output     = %p\n", __func__, (void *) ubatch.output);
+        LLAMA_LOG_DEBUG("%s: n_outputs  = %d\n", __func__, n_outputs);
+
+        if (debug > 1) {
+            int seq_id_max = 0;
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                    for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                        seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
+                    }
+                }
+            }
+            ++seq_id_max;
+
+            LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                std::vector<int8_t> seq_id(seq_id_max);
+
+                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                    seq_id[ubatch.seq_id[i][s]] = 1;
+                }
+
+                std::stringstream ss;
+                for (int s = 0; s < seq_id_max; ++s) {
+                    if (seq_id[s]) {
+                        ss << s%10;
+                    } else {
+                        ss << ".";
+                    }
+                }
+
+                if (ubatch.token) {
+                    LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+                            __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
+                            ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
+                } else {
+                    LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+                            __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
+                }
+            }
+            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+        }
    }
 }
 
@@ -319,25 +788,25 @@ struct llama_batch llama_batch_get_one(
              llama_token * tokens,
                  int32_t n_tokens) {
     return {
-        /*n_tokens
-        /*tokens
-        /*embd
-        /*pos
-        /*n_seq_id
-        /*seq_id
-        /*logits
+        /*n_tokens       =*/ n_tokens,
+        /*tokens         =*/ tokens,
+        /*embd           =*/ nullptr,
+        /*pos            =*/ nullptr,
+        /*n_seq_id       =*/ nullptr,
+        /*seq_id         =*/ nullptr,
+        /*logits         =*/ nullptr,
     };
 }
 
 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = {
-        /*n_tokens
-        /*tokens
-        /*embd
-        /*pos
-        /*n_seq_id
-        /*seq_id
-        /*logits
+        /*n_tokens       =*/ 0,
+        /*tokens         =*/ nullptr,
+        /*embd           =*/ nullptr,
+        /*pos            =*/ nullptr,
+        /*n_seq_id       =*/ nullptr,
+        /*seq_id         =*/ nullptr,
+        /*logits         =*/ nullptr,
    };
 
     if (embd) {