@novastera-oss/llamarn 0.2.6 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (253)
  1. package/android/src/main/cpp/include/llama.h +141 -38
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +58 -24
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +37 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +53 -40
  26. package/cpp/llama.cpp/common/common.h +6 -2
  27. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  28. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  29. package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
  30. package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
  31. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  32. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  33. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  34. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
  35. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  38. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  88. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  90. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  91. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
  93. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
  94. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
  97. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  105. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
  112. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
  113. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  115. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  117. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  138. package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
  139. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  140. package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
  141. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
  142. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
  143. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  144. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  145. package/cpp/llama.cpp/include/llama.h +141 -38
  146. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  147. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  148. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  149. package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
  150. package/cpp/llama.cpp/src/llama-arch.h +25 -1
  151. package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
  152. package/cpp/llama.cpp/src/llama-batch.h +110 -57
  153. package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
  154. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  155. package/cpp/llama.cpp/src/llama-context.cpp +360 -266
  156. package/cpp/llama.cpp/src/llama-context.h +27 -23
  157. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  158. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  159. package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
  160. package/cpp/llama.cpp/src/llama-graph.h +126 -58
  161. package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
  162. package/cpp/llama.cpp/src/llama-hparams.h +16 -2
  163. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
  164. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
  165. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
  166. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
  167. package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
  168. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  169. package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
  170. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
  171. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
  172. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  173. package/cpp/llama.cpp/src/llama-memory.h +73 -36
  174. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  175. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  176. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  177. package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
  178. package/cpp/llama.cpp/src/llama-model.h +26 -0
  179. package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
  180. package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
  181. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  182. package/cpp/llama.cpp/src/llama.cpp +11 -7
  183. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  184. package/cpp/rn-completion.cpp +2 -2
  185. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  186. package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
  187. package/ios/include/chat.h +1 -1
  188. package/ios/include/common.h +6 -2
  189. package/ios/include/llama.h +141 -38
  190. package/ios/libs/llama.xcframework/Info.plist +15 -15
  191. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  192. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  193. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  194. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  195. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
  196. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  197. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  198. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  199. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  200. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  201. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  202. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  203. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  204. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  205. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  206. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
  207. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  208. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  209. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
  210. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  211. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  219. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  220. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  221. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  222. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  223. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
  224. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  225. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  226. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  227. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  228. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  231. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  232. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  233. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
  234. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  235. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  236. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
  237. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  238. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  239. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
  240. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
  241. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  242. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  243. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  244. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  245. package/package.json +1 -2
  246. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  247. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  248. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  249. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  250. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  251. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  252. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  253. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -1,320 +1,782 @@
1
1
  #include "llama-batch.h"
2
2
 
3
+ #include "llama-impl.h"
4
+ #include "llama-vocab.h"
5
+ #include "llama-memory.h"
6
+
3
7
  #include <cassert>
4
8
  #include <cstring>
5
9
  #include <algorithm>
10
+ #include <sstream>
6
11
 
7
- llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
8
- // clear empty sequences
9
- // the previous ubatch is assumed to be gone,
10
- // so nothing should refer to values in these sequences anymore.
11
- for (size_t i = seq.size(); i-- > 0;) {
12
- if (seq[i].length == 0) {
13
- seq.pop_back();
14
- } else {
15
- break;
12
+ llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
13
+ const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
14
+ debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
15
+
16
+ seq_pos.resize(LLAMA_MAX_SEQ);
17
+ seq_cpl.resize(LLAMA_MAX_SEQ);
18
+ for (auto & cur : seq_cpl) {
19
+ cur.resize(LLAMA_MAX_SEQ);
20
+ }
21
+
22
+ seq_idx.resize(LLAMA_MAX_SEQ, -1);
23
+ }
24
+
25
+ bool llama_batch_allocr::init(
26
+ const llama_batch & batch_inp,
27
+ const llama_vocab & vocab,
28
+ const llama_memory_i * memory,
29
+ uint32_t n_embd,
30
+ bool output_all) {
31
+ clear();
32
+
33
+ batch = batch_inp;
34
+
35
+ this->vocab = &vocab;
36
+
37
+ GGML_ASSERT(batch.n_tokens > 0);
38
+
39
+ //
40
+ // validate input batch
41
+ //
42
+
43
+ if (batch.token) {
44
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
45
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
46
+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
47
+ return false;
48
+ }
16
49
  }
17
50
  }
18
51
 
19
- udatas.push_back({});
52
+ if (batch.seq_id) {
53
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
54
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
55
+ if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
56
+ LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
57
+ return false;
58
+ }
59
+ }
60
+ }
61
+ }
20
62
 
21
- auto & udata = udatas.back();
63
+ //
64
+ // auto-generate missing fields
65
+ //
22
66
 
23
- udata.token.resize(!has_embd ? n_ubatch : 0);
24
- udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
25
- udata.pos.resize(n_ubatch);
26
- udata.n_seq_id.resize(n_ubatch);
27
- udata.seq_id.resize(n_ubatch);
28
- udata.output.resize(n_ubatch);
67
+ if (!batch.n_seq_id) {
68
+ n_seq_id.resize(batch.n_tokens);
69
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
70
+ n_seq_id[i] = seq_id_0.size();
71
+ }
72
+ batch.n_seq_id = n_seq_id.data();
73
+ }
29
74
 
30
- llama_ubatch ubatch = {
31
- /*equal_seqs =*/ true,
32
- /*n_tokens =*/ 0,
33
- /*n_seq_tokens =*/ 0,
34
- /*n_seqs =*/ 0,
35
- /*token =*/ !has_embd ? udata.token.data() : nullptr,
36
- /*embd =*/ has_embd ? udata.embd.data() : nullptr,
37
- /*pos =*/ udata.pos.data(),
38
- /*n_seq_id =*/ udata.n_seq_id.data(),
39
- /*seq_id =*/ udata.seq_id.data(),
40
- /*output =*/ udata.output.data(),
41
- };
75
+ if (!batch.seq_id) {
76
+ seq_id.resize(batch.n_tokens + 1);
77
+ seq_id[batch.n_tokens] = NULL;
78
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
79
+ seq_id[i] = seq_id_0.data();
80
+ }
81
+ batch.seq_id = seq_id.data();
82
+ }
42
83
 
43
- return ubatch;
44
- }
84
+ if (!batch.pos) {
85
+ pos.resize(batch.n_tokens);
45
86
 
46
- void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
47
- GGML_ASSERT(batch != nullptr);
48
- GGML_ASSERT(length <= seq.length);
49
- // Can only add sequences of equal lengths to a batch,
50
- // otherwise it isn't clear to which sequence a token belongs
51
- GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
52
- GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
53
- // NOTE: loops are separated for cache-friendliness
54
- if (batch->token) {
55
- if (ubatch.equal_seqs) {
56
- for (size_t i = 0; i < length; ++i) {
57
- ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
87
+ // initialize the starting position for each sequence based on the positions in the memory
88
+ llama_pos p0[LLAMA_MAX_SEQ];
89
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
90
+ if (!memory) {
91
+ // if no memory -> start from 0
92
+ p0[s] = 0;
93
+ } else {
94
+ p0[s] = memory->seq_pos_max(s) + 1;
58
95
  }
59
- } else {
60
- // simple split
61
- ubatch.token = batch->token + seq.offset;
62
96
  }
63
- } else {
64
- ubatch.token = nullptr;
65
- }
66
- if (batch->embd) {
67
- if (ubatch.equal_seqs) {
68
- for (size_t i = 0; i < length; ++i) {
69
- memcpy(
70
- ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
71
- batch->embd + (n_embd * ids[seq.offset + i]),
72
- n_embd * sizeof(float)
73
- );
97
+
98
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
99
+ const llama_seq_id seq_id = batch.seq_id[i][0];
100
+
101
+ pos[i] = p0[seq_id];
102
+
103
+ // update the starting position for all sequences that are assigned to the this token
104
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
105
+ const llama_seq_id seq_id = batch.seq_id[i][s];
106
+
107
+ p0[seq_id] = pos[i] + 1;
74
108
  }
109
+ }
110
+
111
+ batch.pos = pos.data();
112
+ }
113
+
114
+ if (!batch.logits) {
115
+ if (output_all) {
116
+ // return the output for all tokens
117
+ output.resize(batch.n_tokens, true);
75
118
  } else {
76
- // simple split
77
- ubatch.embd = batch->embd + (n_embd * seq.offset);
119
+ // return the output only for the last token
120
+ output.resize(batch.n_tokens, false);
121
+ output[output.size() - 1] = true;
78
122
  }
79
- } else {
80
- ubatch.embd = nullptr;
123
+
124
+ batch.logits = output.data();
125
+ } else if (output_all) {
126
+ bool warn = false;
127
+
128
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
129
+ if (batch.logits[i] == 0) {
130
+ warn = true;
131
+ }
132
+ }
133
+
134
+ if (warn) {
135
+ LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
136
+
137
+ output.resize(batch.n_tokens, true);
138
+ batch.logits = output.data();
139
+ }
140
+ }
141
+
142
+ //
143
+ // compute stats
144
+ //
145
+
146
+ this->n_embd = n_embd;
147
+
148
+ // count the outputs in this batch
149
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
150
+ n_outputs += batch.logits[i] != 0;
81
151
  }
82
- if (ubatch.equal_seqs) {
83
- for (size_t i = 0; i < length; ++i) {
84
- ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
152
+
153
+ // determine coupled sequences
154
+ // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
155
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
156
+ const llama_seq_id s0 = batch.seq_id[i][0];
157
+
158
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
159
+ const llama_seq_id s1 = batch.seq_id[i][s];
160
+
161
+ seq_pos[s1].insert(batch.pos[i]);
162
+
163
+ if (s > 0) {
164
+ // mark that sequence s1 is coupled to s0
165
+ seq_cpl[s1][s0] = true;
166
+
167
+ // note: tracking the other way around is not necessary for now
168
+ //seq_cpl[s0][s1] = true;
169
+ }
85
170
  }
86
- } else {
87
- // simple split
88
- ubatch.pos = batch->pos + seq.offset;
89
171
  }
90
- if (ubatch.equal_seqs) {
91
- ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
92
- if (seq.seq_id) {
93
- ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
172
+
173
+ // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
174
+ {
175
+ seq_set_t seq_set_unq;
176
+
177
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
178
+ seq_set_t cur;
179
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
180
+ const llama_seq_id seq_id = batch.seq_id[i][s];
181
+
182
+ cur .set(seq_id);
183
+ seq_set_unq.set(seq_id);
184
+ }
185
+
186
+ seq_set.push_back(cur);
187
+ seq_set_map[cur].push_back(i);
94
188
  }
95
- } else {
96
- // simple split
97
- if (batch->n_seq_id) {
98
- ubatch.n_seq_id = batch->n_seq_id + seq.offset;
99
- } else {
100
- for (size_t i = 0; i < length; ++i) {
101
- ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
189
+
190
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
191
+ if (seq_set_unq.test(s)) {
192
+ seq_idx[s] = seq_id_unq.size();
193
+ seq_id_unq.push_back(s);
102
194
  }
103
195
  }
104
- if (batch->seq_id) {
105
- ubatch.seq_id = batch->seq_id + seq.offset;
196
+ }
197
+
198
+ if (debug > 0) {
199
+ LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
200
+
201
+ llama_ubatch ubatch {
202
+ /*.equal_seqs =*/ false,
203
+ /*.n_tokens =*/ (uint32_t) batch.n_tokens,
204
+ /*.n_seq_tokens =*/ (uint32_t) 1,
205
+ /*.n_seqs =*/ (uint32_t) batch.n_tokens,
206
+ /*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
207
+ /*.token =*/ batch.token,
208
+ /*.embd =*/ batch.embd,
209
+ /*.pos =*/ batch.pos,
210
+ /*.n_seq_id =*/ batch.n_seq_id,
211
+ /*.seq_id =*/ batch.seq_id,
212
+ /*.seq_id_unq =*/ this->seq_id_unq.data(),
213
+ /*.seq_idx =*/ this->seq_idx.data(),
214
+ /*.output =*/ batch.logits,
215
+ };
216
+
217
+ ubatch_print(ubatch, debug);
218
+
219
+ LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
220
+ for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
221
+ if (seq_pos[s0].empty()) {
222
+ continue;
223
+ }
224
+
225
+ std::stringstream ss;
226
+ for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
227
+ if (seq_cpl[s0][s1]) {
228
+ ss << s1 << " ";
229
+ }
230
+ }
231
+
232
+ LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
233
+ __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
106
234
  }
235
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
107
236
  }
108
- if (logits_all) {
109
- for (size_t i = 0; i < length; ++i) {
110
- ubatch.output[ubatch.n_tokens + i] = 1;
111
- out_ids.push_back(ids[seq.offset + i]);
237
+
238
+ //
239
+ // consistency checks
240
+ //
241
+
242
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
243
+ if (seq_pos[s].empty()) {
244
+ continue;
112
245
  }
113
- } else if (batch->logits) {
114
- if (ubatch.equal_seqs) {
115
- for (size_t i = 0; i < length; ++i) {
116
- size_t id = ids[seq.offset + i];
117
- int8_t is_output = batch->logits[id];
118
- ubatch.output[ubatch.n_tokens + i] = is_output;
119
- if (is_output) { out_ids.push_back(id); }
246
+
247
+ const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
248
+
249
+ if (p0 >= 0) {
250
+ bool ok = true;
251
+
252
+ if (batch.token) {
253
+ if (seq_pos_min(s) != p0 + 1) {
254
+ ok = false;
255
+ }
256
+ } else {
257
+ assert(batch.embd);
258
+
259
+ // for embeddings (typically used as vision input), we allow them to have repeating positions
260
+ // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
261
+ if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
262
+ ok = false;
263
+ }
120
264
  }
121
- } else {
122
- // simple split
123
- ubatch.output = batch->logits + seq.offset;
124
- for (size_t i = 0; i < length; ++i) {
125
- if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
265
+
266
+ if (!ok) {
267
+ LLAMA_LOG_ERROR(
268
+ "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
269
+ " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
270
+ " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
271
+ " it is required that the sequence positions remain consecutive: Y = X + 1\n",
272
+ __func__, s, s, p0, s, seq_pos_min(s));
273
+
274
+ return false;
126
275
  }
127
276
  }
128
- } else {
129
- // only get last output
130
- for (size_t i = 0; i < length; ++i) {
131
- size_t id = ids[seq.offset + i];
132
- int8_t is_last = id == ids.size() - 1;
133
- ubatch.output[ubatch.n_tokens + i] = is_last;
134
- if (is_last) { out_ids.push_back(id); }
135
- }
136
- }
137
- if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
138
- ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
139
- }
140
- ubatch.n_tokens += length;
141
- ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
142
- seq.offset += length;
143
- seq.length -= length;
144
- n_tokens -= length;
145
- GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
146
- }
147
277
 
148
- llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
149
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
150
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
151
- ubatch.equal_seqs = false;
152
- if (!seq.empty()) {
153
- llama_sbatch_seq & s = seq[0];
154
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
155
- GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
156
- add_seq_to_ubatch(ubatch, s, length);
157
- }
158
- return ubatch;
159
- }
278
+ if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
279
+ LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
280
+ return false;
281
+ }
282
+ }
283
+
284
+ if (memory) {
285
+ for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
286
+ for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
287
+ if (seq_cpl[s0][s1]) {
288
+ if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
289
+ memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
290
+ LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
291
+ return false;
292
+ }
293
+ }
294
+ }
295
+ }
296
+ }
297
+
298
+ // disallow partial sequence sub-sets:
299
+ //
300
+ // invalid: x
301
+ // i: 0 1 2 ...
302
+ // ---------------------------------------
303
+ // seq_id[i][0]: 0 0 1
304
+ // seq_id[i][1]: 1 1 2
305
+ // seq_id[i][2]: 2
306
+ //
307
+ // disallow decreasing sequence positions:
308
+ //
309
+ // invalid: x
310
+ // i: 0 1 2 3 4 5 6 ...
311
+ // ---------------------------------------
312
+ // pos[i]: 4 5 0 1 6 2 3
313
+ // seq_id[i][0]: 0 0 1 1 0 1 0
314
+ //
315
+ {
316
+ seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
317
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
318
+ cur_seq_set[s].set();
319
+ }
320
+
321
+ llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
322
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
323
+ cur_seq_pos[s] = -1;
324
+ }
325
+
326
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
327
+ const llama_pos pos = batch.pos[i];
328
+
329
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
330
+ const llama_seq_id seq_id = batch.seq_id[i][s];
160
331
 
161
- llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
162
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
163
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
164
- if (!seq.empty()) {
165
- size_t length = 0;
166
- size_t n_tokens_in_ubatch = 0;
167
- GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
168
- // smallest first, because it's easier to split this way;
169
- // starting from the end to pop in constant time.
170
- for (size_t i = seq.size(); i-- > 0;) {
171
- llama_sbatch_seq & s = seq[i];
172
- GGML_ASSERT(s.length > 0);
173
- if (length == 0) {
174
- length = s.length < n_ubatch ? s.length : n_ubatch;
332
+ cur_seq_set[seq_id] &= seq_set[i];
333
+
334
+ if (cur_seq_set[seq_id].none()) {
335
+ LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
336
+ return false;
337
+ }
338
+
339
+ if (pos < cur_seq_pos[seq_id]) {
340
+ LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
341
+ return false;
342
+ }
175
343
  }
176
- add_seq_to_ubatch(ubatch, s, length);
177
- n_tokens_in_ubatch += length;
178
- // shared prompts can't be mixed with any of their sequences,
179
- // so it's safer to compute them in their own ubatch
180
- if (s.n_seq_id > 1) { break; }
181
- // stop when there isn't enough space for another sequence
182
- if (length + n_tokens_in_ubatch > n_ubatch) { break; }
183
344
  }
184
345
  }
185
- return ubatch;
346
+
347
+ split_reset();
348
+
349
+ return true;
186
350
  }
187
351
 
188
- llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
189
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
190
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
191
- if (!seq.empty()) {
192
- llama_sbatch_seq & s = seq[seq.size() - 1];
193
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
194
- GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
195
- add_seq_to_ubatch(ubatch, s, length);
352
+ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
353
+ const uint32_t n_tokens = n_seq_tokens*n_seqs;
354
+
355
+ clear();
356
+ split_reset();
357
+
358
+ ubatches.emplace_back();
359
+
360
+ auto & ubatch = ubatches.back();
361
+
362
+ ubatch.token .resize(n_tokens);
363
+ ubatch.embd .clear();
364
+ ubatch.pos .resize(n_tokens);
365
+ ubatch.n_seq_id .resize(n_tokens);
366
+ ubatch.seq_id .resize(n_tokens);
367
+ ubatch.seq_id_unq.resize(0);
368
+ ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
369
+ ubatch.output .resize(n_tokens);
370
+
371
+ for (uint32_t s = 0; s < n_seqs; ++s) {
372
+ ubatch.seq_idx[s] = s;
373
+ ubatch.seq_id_unq.push_back(s);
196
374
  }
197
- return ubatch;
375
+
376
+ llama_ubatch res {
377
+ /*.equal_seqs =*/ true,
378
+ /*.n_tokens =*/ n_tokens,
379
+ /*.n_seq_tokens =*/ n_seq_tokens,
380
+ /*.n_seqs =*/ n_seqs,
381
+ /*.n_seqs_unq =*/ n_seqs,
382
+
383
+ /*.token =*/ ubatch.token.data(),
384
+ /*.embd =*/ nullptr,
385
+ /*.pos =*/ ubatch.pos.data(),
386
+ /*.n_seq_id =*/ ubatch.n_seq_id.data(),
387
+ /*.seq_id =*/ ubatch.seq_id.data(),
388
+ /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
389
+ /*.seq_idx =*/ ubatch.seq_idx.data(),
390
+ /*.output =*/ ubatch.output.data(),
391
+ };
392
+
393
+ return res;
198
394
  }
199
395
 
200
- llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
201
- GGML_ASSERT(batch.n_tokens >= 0);
202
- this->batch = &batch;
203
- this->n_embd = n_embd;
204
- this->logits_all = logits_all;
396
+ const llama_batch & llama_batch_allocr::get_batch() const {
397
+ return batch;
398
+ }
399
+
400
+ uint32_t llama_batch_allocr::get_n_tokens() const {
401
+ return batch.n_tokens;
402
+ }
403
+
404
+ uint32_t llama_batch_allocr::get_n_outputs() const {
405
+ return n_outputs;
406
+ }
407
+
408
+ std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
409
+ return out_ids;
410
+ }
205
411
 
206
- n_tokens = batch.n_tokens;
207
- ids.resize(n_tokens);
412
+ llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
413
+ return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
414
+ }
415
+
416
+ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
417
+ return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
418
+ }
419
+
420
+ void llama_batch_allocr::split_reset() {
208
421
  out_ids.clear();
209
- // TODO: reserve out_ids and seq
210
-
211
- for (size_t i = 0; i < n_tokens; ++i) {
212
- ids[i] = i;
213
- }
214
-
215
- if (simple_split) {
216
- seq.resize(1);
217
- llama_sbatch_seq & s = seq[0];
218
- s.n_seq_id = 0;
219
- s.seq_id = nullptr;
220
- s.offset = 0;
221
- s.length = n_tokens;
222
- return;
223
- }
224
-
225
- std::sort(ids.begin(), ids.end(),
226
- [&batch](size_t a, size_t b) {
227
- int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
228
- int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
229
- // sort by seq_id, then by pos
230
- if (n_seq_a == n_seq_b) {
231
- if (batch.seq_id) {
232
- for (int32_t i = 0; i < n_seq_a; ++i) {
233
- llama_seq_id seq_id_a = batch.seq_id[a][i];
234
- llama_seq_id seq_id_b = batch.seq_id[b][i];
235
- // smaller seq_ids go first
236
- if (seq_id_a != seq_id_b) {
237
- return seq_id_a < seq_id_b;
238
- }
239
- }
240
- }
241
- // when all else is equal, sort by pos
242
- if (batch.pos) {
243
- return batch.pos[a] < batch.pos[b];
244
- }
245
- // no pos, sort by id
246
- return a < b;
247
- }
248
- // shared prompts go first
249
- return n_seq_a > n_seq_b;
250
- }
251
- );
252
-
253
- // init seq
254
- llama_sbatch_seq * last_seq = nullptr;
255
-
256
- for (size_t i = 0; i < n_tokens; ++i) {
257
- const size_t bi = ids[i];
258
- const int32_t n_seqs = batch.n_seq_id[bi];
259
- llama_seq_id * seq_ids = batch.seq_id[bi];
260
- if (last_seq != nullptr) {
261
- bool same = n_seqs == last_seq->n_seq_id;
262
- for (int32_t j = 0; same && j < n_seqs; ++j) {
263
- if (seq_ids[j] != last_seq->seq_id[j]) {
264
- same = false;
265
- }
422
+
423
+ used.clear();
424
+ used.resize(get_n_tokens(), false);
425
+
426
+ ubatches.clear();
427
+ }
428
+
429
+ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
430
+ // find the first unused token
431
+ uint32_t cur_idx = 0;
432
+ while (cur_idx < used.size() && used[cur_idx]) {
433
+ ++cur_idx;
434
+ }
435
+
436
+ // we are done
437
+ if (cur_idx >= used.size()) {
438
+ return {};
439
+ }
440
+
441
+ std::vector<int32_t> idxs;
442
+
443
+ while (true) {
444
+ idxs.push_back(cur_idx);
445
+
446
+ used[cur_idx] = true;
447
+
448
+ ++cur_idx;
449
+
450
+ if (cur_idx >= used.size()) {
451
+ break;
452
+ }
453
+
454
+ if (idxs.size() >= n_ubatch) {
455
+ break;
456
+ }
457
+ }
458
+
459
+ return ubatch_add(idxs, idxs.size(), false);
460
+ }
461
+
462
+ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
463
+ std::vector<seq_set_t> cur_seq_set;
464
+
465
+ // determine the non-overlapping sequence sets participating in this ubatch
466
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
467
+ if (used[i]) {
468
+ continue;
469
+ }
470
+
471
+ bool add = true;
472
+
473
+ for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
474
+ // no overlap with existing sequence sets:
475
+ if (!(cur_seq_set[s] & seq_set[i]).none()) {
476
+ add = false;
477
+ break;
266
478
  }
267
- if (same) {
268
- last_seq->length += 1;
269
- continue;
479
+ }
480
+
481
+ if (add) {
482
+ cur_seq_set.push_back(seq_set[i]);
483
+
484
+ if (cur_seq_set.size() > n_ubatch) {
485
+ break;
270
486
  }
271
487
  }
272
- llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
273
- seq.push_back(new_seq);
274
- last_seq = &seq.back();
275
488
  }
276
489
 
277
- // keep shared prompts first at the end, then sort by length descending.
278
- std::sort(seq.begin(), seq.end(),
279
- [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
280
- if (a.n_seq_id == b.n_seq_id) {
281
- return a.length > b.length;
282
- }
283
- return a.n_seq_id < b.n_seq_id;
490
+ const uint32_t n_seqs = cur_seq_set.size();
491
+
492
+ // we are done
493
+ if (n_seqs == 0) {
494
+ return {};
495
+ }
496
+
497
+ // the current batch index of each sequence set
498
+ std::vector<int32_t> cur_idx(n_seqs, 0);
499
+
500
+ for (uint32_t s = 0; s < n_seqs; ++s) {
501
+ while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
502
+ ++cur_idx[s];
503
+ }
504
+ }
505
+
506
+ // the list of batch indices for each sequence set
507
+ // at the end we will concat these to get the final ubatch
508
+ std::vector<idx_vec_t> idxs_per_seq(n_seqs);
509
+
510
+ while (true) {
511
+ // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
512
+ // if we haven't reached n_ubatch
513
+ bool can_expand = true;
514
+
515
+ for (uint32_t s = 0; s < n_seqs; ++s) {
516
+ if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
517
+ can_expand = false;
518
+ break;
284
519
  }
285
- );
520
+ }
521
+
522
+ if (!can_expand) {
523
+ break;
524
+ }
525
+
526
+ for (uint32_t s = 0; s < n_seqs; ++s) {
527
+ const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
528
+
529
+ idxs_per_seq[s].push_back(idx);
530
+
531
+ used[idx] = true;
532
+
533
+ ++cur_idx[s];
534
+ }
535
+
536
+ if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
537
+ break;
538
+ }
539
+ }
540
+
541
+ // concat the per-sequence-set lists
542
+ std::vector<int32_t> idxs;
543
+
544
+ for (uint32_t s = 0; s < n_seqs; ++s) {
545
+ idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
546
+ }
547
+
548
+ return ubatch_add(idxs, n_seqs, true);
286
549
  }
287
550
 
288
- llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
289
- batch = in_batch;
290
- GGML_ASSERT(batch.n_tokens > 0);
291
- if (!batch.pos) {
292
- assert(p0 >= 0);
293
- pos.resize(batch.n_tokens);
294
- for (int32_t i = 0; i < batch.n_tokens; i++) {
295
- pos[i] = p0 + i;
551
+ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
552
+ // find the first unused token
553
+ uint32_t cur_idx = 0;
554
+ while (cur_idx < used.size() && used[cur_idx]) {
555
+ ++cur_idx;
556
+ }
557
+
558
+ // we are done
559
+ if (cur_idx >= used.size()) {
560
+ return {};
561
+ }
562
+
563
+ // this is the starting sequence set
564
+ // we allow adding tokens only if their sequence set is a subset of the current sequence set
565
+ auto cur_seq_set = seq_set[cur_idx];
566
+
567
+ std::vector<int32_t> idxs;
568
+
569
+ while (true) {
570
+ idxs.push_back(cur_idx);
571
+
572
+ used[cur_idx] = true;
573
+
574
+ if (idxs.size() >= n_ubatch) {
575
+ break;
296
576
  }
297
- batch.pos = pos.data();
577
+
578
+ do {
579
+ ++cur_idx;
580
+ } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
581
+
582
+ if (cur_idx == get_n_tokens()) {
583
+ break;
584
+ }
585
+
586
+ cur_seq_set = seq_set[cur_idx];
298
587
  }
299
- if (!batch.n_seq_id) {
300
- n_seq_id.resize(batch.n_tokens);
301
- for (int32_t i = 0; i < batch.n_tokens; i++) {
302
- n_seq_id[i] = seq_id_0.size();
588
+
589
+ return ubatch_add(idxs, 1, true);
590
+ }
591
+
592
+ void llama_batch_allocr::clear() {
593
+ n_outputs = 0;
594
+
595
+ batch = {};
596
+
597
+ pos .clear();
598
+ n_seq_id .clear();
599
+ seq_id .clear();
600
+ seq_id_unq.clear();
601
+ output .clear();
602
+
603
+ for (auto & cur : seq_pos) {
604
+ cur.clear();
605
+ }
606
+
607
+ for (auto & cur : seq_cpl) {
608
+ std::fill(cur.begin(), cur.end(), false);
609
+ }
610
+
611
+ seq_set.clear();
612
+
613
+ seq_set_map.clear();
614
+
615
+ std::fill(seq_idx.begin(), seq_idx.end(), -1);
616
+ }
617
+
618
+ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
619
+ const uint32_t n_tokens = idxs.size();
620
+
621
+ assert(n_tokens%n_seqs == 0);
622
+
623
+ ubatches.emplace_back();
624
+
625
+ auto & ubatch = ubatches.back();
626
+
627
+ const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
628
+
629
+ const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
630
+ const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
631
+
632
+ ubatch.token .resize(n_tokens);
633
+ ubatch.embd .resize(n_embd_all);
634
+ ubatch.pos .resize(n_pos_all);
635
+ ubatch.n_seq_id .resize(n_tokens);
636
+ ubatch.seq_id .resize(n_tokens);
637
+ ubatch.seq_id_unq.resize(0);
638
+ ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
639
+ ubatch.output .resize(n_tokens);
640
+
641
+ seq_set_t seq_set_unq;
642
+
643
+ for (size_t i = 0; i < idxs.size(); ++i) {
644
+ if (batch.token) {
645
+ ubatch.token[i] = batch.token[idxs[i]];
646
+ }
647
+
648
+ if (batch.embd) {
649
+ memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
650
+ }
651
+
652
+ for (int j = 0; j < n_pos_cur; ++j) {
653
+ ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
654
+ }
655
+
656
+ ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
657
+ ubatch.seq_id[i] = batch.seq_id[idxs[i]];
658
+ ubatch.output[i] = batch.logits[idxs[i]];
659
+
660
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
661
+ seq_set_unq.set(ubatch.seq_id[i][s]);
662
+ }
663
+
664
+ if (ubatch.output[i]) {
665
+ out_ids.push_back(idxs[i]);
303
666
  }
304
- batch.n_seq_id = n_seq_id.data();
305
667
  }
306
- if (!batch.seq_id) {
307
- seq_id.resize(batch.n_tokens + 1);
308
- seq_id[batch.n_tokens] = NULL;
309
- for (int32_t i = 0; i < batch.n_tokens; i++) {
310
- seq_id[i] = seq_id_0.data();
668
+
669
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
670
+ if (seq_set_unq.test(s)) {
671
+ ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
672
+ ubatch.seq_id_unq.push_back(s);
311
673
  }
312
- batch.seq_id = seq_id.data();
313
674
  }
314
- if (!batch.logits) {
315
- logits.resize(batch.n_tokens);
316
- logits[logits.size() - 1] = true;
317
- batch.logits = logits.data();
675
+
676
+ llama_ubatch res {
677
+ /*.equal_seqs =*/ equal_seqs,
678
+ /*.n_tokens =*/ n_tokens,
679
+ /*.n_seq_tokens =*/ n_tokens/n_seqs,
680
+ /*.n_seqs =*/ n_seqs,
681
+ /*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(),
682
+
683
+ /*.token =*/ batch.token ? ubatch.token.data() : nullptr,
684
+ /*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr,
685
+ /*.pos =*/ ubatch.pos.data(),
686
+ /*.n_seq_id =*/ ubatch.n_seq_id.data(),
687
+ /*.seq_id =*/ ubatch.seq_id.data(),
688
+ /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
689
+ /*.seq_idx =*/ ubatch.seq_idx.data(),
690
+ /*.output =*/ ubatch.output.data(),
691
+ };
692
+
693
+ if (debug > 0) {
694
+ LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
695
+
696
+ ubatch_print(res, debug);
697
+ }
698
+
699
+ return res;
700
+ }
701
+
702
+ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
703
+ if (debug > 0) {
704
+ LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
705
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
706
+ LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
707
+ LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
708
+ LLAMA_LOG_DEBUG("%s: n_seqs_unq = %d\n", __func__, ubatch.n_seqs_unq);
709
+
710
+ std::stringstream ss_seq_id_unq;
711
+ std::stringstream ss_seq_idx;
712
+
713
+ ss_seq_id_unq << "[ ";
714
+ ss_seq_idx << "[";
715
+
716
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
717
+ ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
718
+ }
719
+
720
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
721
+ if (ubatch.seq_idx[s] >= 0) {
722
+ ss_seq_idx << ubatch.seq_idx[s]%10;
723
+ } else {
724
+ ss_seq_idx << ".";
725
+ }
726
+ }
727
+
728
+ ss_seq_id_unq << "]";
729
+ ss_seq_idx << "]";
730
+
731
+ LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
732
+ LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
733
+ LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
734
+ LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
735
+ LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
736
+ LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
737
+ LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
738
+ LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
739
+ LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
740
+
741
+ if (debug > 1) {
742
+ int seq_id_max = 0;
743
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
744
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
745
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
746
+ seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
747
+ }
748
+ }
749
+ }
750
+ ++seq_id_max;
751
+
752
+ LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
753
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
754
+ std::vector<int8_t> seq_id(seq_id_max);
755
+
756
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
757
+ seq_id[ubatch.seq_id[i][s]] = 1;
758
+ }
759
+
760
+ std::stringstream ss;
761
+ for (int s = 0; s < seq_id_max; ++s) {
762
+ if (seq_id[s]) {
763
+ ss << s%10;
764
+ } else {
765
+ ss << ".";
766
+ }
767
+ }
768
+
769
+ if (ubatch.token) {
770
+ LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
771
+ __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
772
+ ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
773
+ } else {
774
+ LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
775
+ __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
776
+ }
777
+ }
778
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
779
+ }
318
780
  }
319
781
  }
320
782
 
@@ -326,25 +788,25 @@ struct llama_batch llama_batch_get_one(
326
788
  llama_token * tokens,
327
789
  int32_t n_tokens) {
328
790
  return {
329
- /*n_tokens =*/ n_tokens,
330
- /*tokens =*/ tokens,
331
- /*embd =*/ nullptr,
332
- /*pos =*/ nullptr,
333
- /*n_seq_id =*/ nullptr,
334
- /*seq_id =*/ nullptr,
335
- /*logits =*/ nullptr,
791
+ /*n_tokens =*/ n_tokens,
792
+ /*tokens =*/ tokens,
793
+ /*embd =*/ nullptr,
794
+ /*pos =*/ nullptr,
795
+ /*n_seq_id =*/ nullptr,
796
+ /*seq_id =*/ nullptr,
797
+ /*logits =*/ nullptr,
336
798
  };
337
799
  }
338
800
 
339
801
  struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
340
802
  llama_batch batch = {
341
- /*n_tokens =*/ 0,
342
- /*tokens =*/ nullptr,
343
- /*embd =*/ nullptr,
344
- /*pos =*/ nullptr,
345
- /*n_seq_id =*/ nullptr,
346
- /*seq_id =*/ nullptr,
347
- /*logits =*/ nullptr,
803
+ /*n_tokens =*/ 0,
804
+ /*tokens =*/ nullptr,
805
+ /*embd =*/ nullptr,
806
+ /*pos =*/ nullptr,
807
+ /*n_seq_id =*/ nullptr,
808
+ /*seq_id =*/ nullptr,
809
+ /*logits =*/ nullptr,
348
810
  };
349
811
 
350
812
  if (embd) {