@novastera-oss/llamarn 0.2.7 → 0.2.9

This diff reflects the changes between publicly released versions of the package as they appear in its public registry, and is provided for informational purposes only.
Files changed (186)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +56 -22
  11. package/cpp/build-info.cpp +2 -2
  12. package/cpp/llama.cpp/CMakeLists.txt +1 -1
  13. package/cpp/llama.cpp/common/arg.cpp +7 -0
  14. package/cpp/llama.cpp/common/common.cpp +3 -0
  15. package/cpp/llama.cpp/common/common.h +1 -0
  16. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  17. package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
  18. package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
  19. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  20. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  21. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
  22. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  23. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
  24. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  25. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  26. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  27. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  28. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  30. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  32. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  33. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  34. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  35. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  62. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
  64. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
  65. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  66. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
  67. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  68. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  69. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  70. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  71. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  72. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  73. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  74. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  76. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  77. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
  78. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  79. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  80. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  81. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  82. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  83. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  89. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  90. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  92. package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
  93. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  94. package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
  95. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
  96. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
  97. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  98. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  99. package/cpp/llama.cpp/include/llama.h +8 -3
  100. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  101. package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
  102. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  103. package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
  104. package/cpp/llama.cpp/src/llama-batch.h +98 -70
  105. package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
  106. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  107. package/cpp/llama.cpp/src/llama-context.h +13 -13
  108. package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
  109. package/cpp/llama.cpp/src/llama-graph.h +44 -32
  110. package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
  111. package/cpp/llama.cpp/src/llama-hparams.h +8 -0
  112. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
  113. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
  114. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
  115. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
  116. package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
  117. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
  118. package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
  119. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
  120. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  121. package/cpp/llama.cpp/src/llama-memory.h +18 -22
  122. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  123. package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
  124. package/cpp/llama.cpp/src/llama-model.h +22 -0
  125. package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
  126. package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
  127. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  128. package/cpp/rn-utils.h +3 -0
  129. package/ios/include/common.h +1 -0
  130. package/ios/include/llama.h +8 -3
  131. package/ios/libs/llama.xcframework/Info.plist +19 -19
  132. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  133. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  134. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  135. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  136. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
  137. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  138. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  139. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  140. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  141. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  142. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  143. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  144. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  145. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  146. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  147. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
  148. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  149. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  150. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
  151. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  152. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  153. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
  154. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  155. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  160. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  161. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  162. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  163. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  164. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
  165. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  168. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  173. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
  175. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
  178. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  183. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  184. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  186. package/package.json +1 -1
package/cpp/llama.cpp/src/llama-batch.h
@@ -2,86 +2,44 @@
 
 #include "llama.h"
 
+#include "llama-cparams.h"
+
 #include <array>
 #include <vector>
 #include <set>
+#include <bitset>
+#include <unordered_map>
 
-// very similar to llama_batch,
-// but has more metadata about sequences
+// keep this struct lightweight
+// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence
-    uint32_t n_seqs;
-
-    llama_token  *  token;    // [n_tokens]
-    float        *  embd;     // [n_embd, n_tokens]
-    llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
-    int8_t       *  output;   // [n_tokens]
-};
-
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-
-    llama_seq_id * seq_id;
-
-    size_t offset;
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    size_t n_embd;
-
-    // sorted indices into the batch
-    std::vector<int64_t> ids;
-    // batch indices of the output
-    std::vector<int64_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    const llama_batch * batch = nullptr;
-
-    // buffers for the ubatches
-    // TODO: very hacky, this needs a complete rework
-    struct ubatch_data {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<int8_t>         output;
-    };
-
-    std::vector<ubatch_data> udatas;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch);
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch);
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch);
-
-    llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+    uint32_t n_seq_tokens; // tokens per sequence set
+    uint32_t n_seqs;       // sequence sets in the ubatch
+    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+
+    // seq_id_unq: unique sequence ids in the ubatch
+    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+    //             used for extracting sequence pooled embeddings
+
+    //                          // size               | idx | val
+    llama_token  *  token;      // [n_tokens]         | i   | id, token
+    float        *  embd;       // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;     // [n_tokens]         | i   | -
 };
 
-// a helper for sanitizing and fulfilling a batch
+// a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
-    llama_batch_allocr();
+    llama_batch_allocr(uint32_t n_pos_per_embd);
 
     // sanitize and auto-gen missing data in the input batch
     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
@@ -89,20 +47,57 @@ public:
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
             const llama_memory_i * memory,
-            bool embd_all);
+            uint32_t n_embd,
+            bool output_all);
 
     const llama_batch & get_batch() const;
 
+    uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;
 
+    // the array of output indices in the order they were encountered during the ubatch splitting
+    std::vector<int32_t> & get_out_ids();
+
+    // min/max positions of each sequence in the current ubatch
     llama_pos seq_pos_min(llama_seq_id seq_id) const;
     llama_pos seq_pos_max(llama_seq_id seq_id) const;
 
+    // call once before splitting the batch to reset the internal state
+    void split_reset();
+
+    // simple split, unknown number of sequence sets of unequal lengths
+    llama_ubatch split_simple(uint32_t n_ubatch);
+
+    // make ubatches of equal-length sequences sets
+    llama_ubatch split_equal(uint32_t n_ubatch);
+
+    // sequence-set-wise split - each ubatch contains a single sequence-set
+    llama_ubatch split_seq(uint32_t n_ubatch);
+
+    // a helper method for creating a well-defined ubatch of tokens
+    // TODO: support embeddings if needed in the future
+    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+
 private:
     void clear();
 
+    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+    // for debugging, start with LLAMA_BATCH_DEBUG=2
+    void ubatch_print(const llama_ubatch & ubatch, int debug);
+
     llama_batch batch;
 
+    // only for debugging purposes
+    const llama_vocab * vocab;
+
+    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+    //       ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+    const uint32_t n_pos_per_embd;
+
+    uint32_t n_embd;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -110,10 +105,43 @@ private:
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
+    std::vector<llama_seq_id>   seq_id_unq;
+    std::vector<int32_t>        seq_idx;
     std::vector<int8_t>         output;
 
-    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+    using pos_set_t = std::set<llama_pos>;
+    using seq_cpl_t = std::vector<bool>;
+
+    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
+    using idx_vec_t = std::vector<int32_t>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+    // batch indices of the output
+    std::vector<int32_t> out_ids;
+
+    // used[i] indicates if token i has already been used in a previous ubatch
+    std::vector<bool> used;
+
+    // llama_ubatch points to this data:
+    struct ubatch {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // current splitting state:
+    std::vector<ubatch> ubatches;
 
     int debug;
 };
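
For context on the llama-batch.h rework above: the sequence-aware splitting that used to live in the removed llama_sbatch now belongs to llama_batch_allocr itself (split_reset, split_simple, split_equal, split_seq). The sketch below shows how the new interface appears to be driven; it is inferred from the declarations in this hunk only. The bool return of init() and the consume_batch wrapper are assumptions for illustration, not code from this package.

#include "llama-batch.h" // internal llama.cpp header; declarations shown in the hunk above

// Minimal sketch of the reworked splitting flow, assuming init() reports
// success as a bool, as in upstream llama.cpp.
static bool consume_batch(const llama_batch & batch_inp,
                          const llama_vocab & vocab,
                          uint32_t n_embd,
                          uint32_t n_ubatch) {
    llama_batch_allocr balloc(/*n_pos_per_embd =*/ 1);

    // sanitize/auto-fill the input batch; passing no memory module means
    // positions are derived without sequence-continuity checks
    if (!balloc.init(batch_inp, vocab, /*memory =*/ nullptr, n_embd, /*output_all =*/ false)) {
        return false;
    }

    balloc.split_reset(); // reset internal splitting state before the first split

    while (true) {
        // split_equal()/split_seq() would be the choice for recurrent or hybrid memory
        llama_ubatch ubatch = balloc.split_simple(n_ubatch);
        if (ubatch.n_tokens == 0) {
            break; // entire batch consumed
        }

        // the ubatch's pointers reference buffers owned by balloc and stay
        // valid until the next split_reset()/init()
        // ... evaluate the ubatch here ...
    }

    return true;
}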
package/cpp/llama.cpp/src/llama-chat.cpp
@@ -528,12 +528,17 @@ int32_t llm_chat_apply_template(
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "user") {
-                ss << "User: " << message->content << "\n\nAssistant:";
-            } else {
-                ss << message->content << "\n\n";
+        for (size_t i = 0; i < chat.size(); i++) {
+            std::string role(chat[i]->role);
+            if (role == "system") {
+                ss << "System: " << trim(chat[i]->content) << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << trim(chat[i]->content) << "\n\n";
+                if (i == chat.size() - 1) {
+                    ss << "Assistant:";
+                }
+            } else if (role == "assistant") {
+                ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
             }
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
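
The practical effect of the llama-chat.cpp change above: system and assistant messages previously fell through to a bare else branch with no role prefix, and "Assistant:" was appended after every user turn. The new loop labels every role, trims the content, and emits the generation cue only after the final user message. Below is a self-contained trace of the new branch; the msg struct and trim_copy helper are stand-ins for llama_chat_message and the library's trim, introduced only so the snippet compiles on its own.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct msg { std::string role, content; };

// approximation of the library's trim(): strip surrounding whitespace
static std::string trim_copy(const std::string & s) {
    const auto a = s.find_first_not_of(" \t\r\n");
    if (a == std::string::npos) return "";
    const auto b = s.find_last_not_of(" \t\r\n");
    return s.substr(a, b - a + 1);
}

int main() {
    const std::vector<msg> chat = {
        { "system",    "You are concise." },
        { "user",      "Hi!"              },
        { "assistant", "Hello."           },
        { "user",      "Name one fruit."  },
    };

    // same control flow as the new RWKV World branch
    std::ostringstream ss;
    for (size_t i = 0; i < chat.size(); i++) {
        const std::string & role = chat[i].role;
        if (role == "system") {
            ss << "System: " << trim_copy(chat[i].content) << "\n\n";
        } else if (role == "user") {
            ss << "User: " << trim_copy(chat[i].content) << "\n\n";
            if (i == chat.size() - 1) {
                ss << "Assistant:"; // generation cue, only after the last user turn
            }
        } else if (role == "assistant") {
            ss << "Assistant: " << trim_copy(chat[i].content) << "\n\n";
        }
    }

    // Prints:
    // System: You are concise.
    //
    // User: Hi!
    //
    // Assistant: Hello.
    //
    // User: Name one fruit.
    //
    // Assistant:
    std::cout << ss.str();
}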