@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -20,7 +20,7 @@ llama_context::llama_context(
20
20
  const llama_model & model,
21
21
  llama_context_params params) :
22
22
  model(model),
23
- batch_allocr(std::make_unique<llama_batch_allocr>()) {
23
+ balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
24
24
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
25
25
 
26
26
  t_start_us = model.t_start_us;
@@ -280,8 +280,8 @@ llama_context::llama_context(
280
280
 
281
281
  // simulate full KV cache
282
282
 
283
- const auto mstate = memory->init_full();
284
- if (!mstate) {
283
+ const auto mctx = memory->init_full();
284
+ if (!mctx) {
285
285
  throw std::runtime_error("failed to initialize KV cache");
286
286
  }
287
287
 
@@ -289,7 +289,7 @@ llama_context::llama_context(
289
289
 
290
290
  // reserve pp graph first so that buffers are only allocated once
291
291
  {
292
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
292
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
293
293
  if (!gf) {
294
294
  throw std::runtime_error("failed to allocate compute pp buffers");
295
295
  }
@@ -300,7 +300,7 @@ llama_context::llama_context(
300
300
 
301
301
  // reserve with tg graph to get the number of splits and nodes
302
302
  {
303
- auto * gf = graph_reserve(1, 1, 1, mstate.get());
303
+ auto * gf = graph_reserve(1, 1, 1, mctx.get());
304
304
  if (!gf) {
305
305
  throw std::runtime_error("failed to allocate compute tg buffers");
306
306
  }
@@ -311,7 +311,7 @@ llama_context::llama_context(
311
311
 
312
312
  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
313
313
  {
314
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
314
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
315
315
  if (!gf) {
316
316
  throw std::runtime_error("failed to allocate compute pp buffers");
317
317
  }
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
444
444
  optimize |= memory_force_optimize;
445
445
  memory_force_optimize = false;
446
446
 
447
- const auto mstate = memory->init_update(this, optimize);
448
- switch (mstate->get_status()) {
447
+ const auto mctx = memory->init_update(this, optimize);
448
+ switch (mctx->get_status()) {
449
449
  case LLAMA_MEMORY_STATUS_SUCCESS:
450
450
  {
451
451
  // noop
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
463
463
  }
464
464
  }
465
465
 
466
- if (!mstate->apply()) {
466
+ if (!mctx->apply()) {
467
467
  LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
468
468
  }
469
469
  }
470
470
 
471
471
  // if the memory module did any computation, we have to reserve a new worst-case graph
472
472
  {
473
- const auto mstate = memory->init_full();
474
- if (!mstate) {
475
- throw std::runtime_error("failed to initialize memory state");
473
+ const auto mctx = memory->init_full();
474
+ if (!mctx) {
475
+ throw std::runtime_error("failed to initialize memory context");
476
476
  }
477
477
 
478
478
  const uint32_t n_seqs = cparams.n_seq_max;
479
479
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
480
480
 
481
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
481
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
482
482
  if (!gf) {
483
483
  LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
484
484
  }
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
678
678
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
679
679
  }
680
680
 
681
- llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
682
- if (mstate && !mstate->apply()) {
683
- LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
681
+ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
682
+ if (mctx && !mctx->apply()) {
683
+ LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
684
684
  ret = GGML_STATUS_FAILED;
685
685
  return nullptr;
686
686
  }
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
692
692
  return nullptr;
693
693
  }
694
694
 
695
- auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
695
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
696
696
  if (!res) {
697
697
  LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
698
698
  ret = GGML_STATUS_FAILED;
@@ -722,22 +722,26 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
722
722
  }
723
723
 
724
724
  int llama_context::encode(const llama_batch & batch_inp) {
725
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
726
+
725
727
  if (batch_inp.n_tokens == 0) {
726
728
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
727
729
  return -1;
728
730
  }
729
731
 
732
+ const auto & hparams = model.hparams;
733
+
734
+ const int64_t n_embd = hparams.n_embd;
735
+
730
736
  // note: during encode, we always pass the full sequence starting from pos = 0
731
- if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
737
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
732
738
  LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
733
739
  return -1;
734
740
  }
735
741
 
736
- const llama_batch & batch = batch_allocr->get_batch();
742
+ const uint32_t n_tokens = balloc->get_n_tokens();
737
743
 
738
- const uint32_t n_tokens = batch.n_tokens;
739
-
740
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
744
+ const llama_ubatch ubatch = balloc->split_simple(n_tokens);
741
745
 
742
746
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
743
747
  GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
@@ -751,14 +755,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
751
755
 
752
756
  n_queued_tokens += n_tokens;
753
757
 
754
- const auto & hparams = model.hparams;
755
-
756
- const int64_t n_embd = hparams.n_embd;
757
-
758
- llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
759
-
760
- const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
761
-
762
758
  // reserve output buffer
763
759
  if (output_reserve(n_tokens) < n_tokens) {
764
760
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
@@ -817,34 +813,28 @@ int llama_context::encode(const llama_batch & batch_inp) {
817
813
  {
818
814
  // extract sequence embeddings
819
815
  auto & embd_seq_out = embd_seq;
820
- embd_seq_out.clear();
821
816
 
822
- GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
817
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
818
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
819
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
823
820
 
824
- // TODO: fix indexing [UBATCH_IDX]
825
- for (uint32_t i = 0; i < n_tokens; i++) {
826
- const llama_seq_id seq_id = ubatch.seq_id[i][0];
827
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
828
- continue;
829
- }
830
821
  embd_seq_out[seq_id].resize(n_embd);
831
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
822
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
832
823
  }
833
824
  } break;
834
825
  case LLAMA_POOLING_TYPE_RANK:
835
826
  {
836
827
  // extract the rerank score - n_cls_out floats per sequence
837
828
  auto & embd_seq_out = embd_seq;
829
+
838
830
  const uint32_t n_cls_out = hparams.n_cls_out;
839
831
 
840
- // TODO: fix indexing [UBATCH_IDX]
841
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
842
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
843
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
844
- continue;
845
- }
832
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
833
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
834
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
835
+
846
836
  embd_seq_out[seq_id].resize(n_cls_out);
847
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
837
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
848
838
  }
849
839
  } break;
850
840
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -869,12 +859,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
869
859
  cross.v_embd.resize(cross.n_embd*cross.n_enc);
870
860
  memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
871
861
 
862
+ const auto & batch = balloc->get_batch();
863
+
872
864
  // remember the sequence ids used during the encoding - needed for cross attention later
873
865
  cross.seq_ids_enc.resize(n_tokens);
874
866
  for (uint32_t i = 0; i < n_tokens; i++) {
875
867
  cross.seq_ids_enc[i].clear();
868
+
876
869
  for (int s = 0; s < batch.n_seq_id[i]; s++) {
877
- llama_seq_id seq_id = batch.seq_id[i][s];
870
+ const llama_seq_id seq_id = batch.seq_id[i][s];
871
+
878
872
  cross.seq_ids_enc[i].insert(seq_id);
879
873
  }
880
874
  }
@@ -884,6 +878,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
884
878
  }
885
879
 
886
880
  int llama_context::decode(const llama_batch & batch_inp) {
881
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
882
+
887
883
  if (!memory) {
888
884
  LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
889
885
  return encode(batch_inp);
@@ -894,29 +890,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
894
890
  return -1;
895
891
  }
896
892
 
897
- // when computing embeddings, all tokens are output
898
- const bool embd_all = cparams.embeddings;
899
-
900
- if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
901
- LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
902
- return -1;
903
- }
904
-
905
- const llama_batch & batch = batch_allocr->get_batch();
906
-
907
893
  const auto & vocab = model.vocab;
908
894
  const auto & hparams = model.hparams;
909
895
 
910
896
  const int32_t n_vocab = vocab.n_tokens();
911
897
  const int64_t n_embd = hparams.n_embd;
912
898
 
913
- const uint32_t n_tokens_all = batch.n_tokens;
899
+ // when computing embeddings, all tokens are output
900
+ const bool output_all = cparams.embeddings;
914
901
 
915
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
902
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
903
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
904
+ return -1;
905
+ }
916
906
 
917
- const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
907
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
908
+ const uint32_t n_outputs_all = balloc->get_n_outputs();
918
909
 
919
- if (embd_all) {
910
+ if (output_all) {
920
911
  // require that all tokens are output
921
912
  if (n_outputs_all != n_tokens_all) {
922
913
  LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -942,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
942
933
  // handle any pending defrags/shifts
943
934
  kv_self_update(false);
944
935
 
945
- llama_memory_state_ptr mstate;
936
+ llama_memory_context_ptr mctx;
946
937
 
947
938
  while (true) {
948
- mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
949
- if (!mstate) {
939
+ mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
940
+ if (!mctx) {
950
941
  return -2;
951
942
  }
952
943
 
953
- switch (mstate->get_status()) {
944
+ switch (mctx->get_status()) {
954
945
  case LLAMA_MEMORY_STATUS_SUCCESS:
955
946
  {
956
947
  } break;
957
948
  case LLAMA_MEMORY_STATUS_NO_UPDATE:
958
949
  {
959
- LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
950
+ LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
960
951
 
961
952
  return -2;
962
953
  }
@@ -966,19 +957,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
966
957
  did_optimize = true;
967
958
 
968
959
  if (kv_self_update(true)) {
969
- LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
960
+ LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
970
961
 
971
962
  continue;
972
963
  }
973
964
  }
974
965
 
975
- LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
966
+ LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
976
967
 
977
968
  return 1;
978
969
  }
979
970
  case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
980
971
  {
981
- LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
972
+ LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
982
973
 
983
974
  return -2;
984
975
  }
@@ -996,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
996
987
  int64_t n_outputs_prev = 0;
997
988
 
998
989
  do {
999
- const auto & ubatch = mstate->get_ubatch();
990
+ const auto & ubatch = mctx->get_ubatch();
1000
991
 
1001
992
  // count the outputs in this ubatch
1002
993
  {
@@ -1005,7 +996,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
1005
996
  if (n_outputs_all == n_tokens_all) {
1006
997
  n_outputs_new = ubatch.n_tokens;
1007
998
  } else {
1008
- GGML_ASSERT(ubatch.output);
1009
999
  for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
1010
1000
  n_outputs_new += (int32_t) (ubatch.output[i] != 0);
1011
1001
  }
@@ -1019,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1019
1009
  ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1020
1010
 
1021
1011
  ggml_status status;
1022
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
1012
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
1023
1013
 
1024
1014
  if (!res) {
1025
1015
  // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1028,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
1028
1018
  pos_min[s] = std::numeric_limits<llama_pos>::max();
1029
1019
  }
1030
1020
 
1031
- // TODO: fix sequence indexing
1032
1021
  for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
1033
1022
  const auto & seq_id = ubatch.seq_id[i][0];
1034
1023
 
@@ -1105,27 +1094,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
1105
1094
  // extract sequence embeddings (cleared before processing each batch)
1106
1095
  auto & embd_seq_out = embd_seq;
1107
1096
 
1108
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1109
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
1110
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1111
- continue;
1112
- }
1097
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
1098
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
1099
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
1100
+
1113
1101
  embd_seq_out[seq_id].resize(n_embd);
1114
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
1102
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
1115
1103
  }
1116
1104
  } break;
1117
1105
  case LLAMA_POOLING_TYPE_RANK:
1118
1106
  {
1119
- // extract the rerank score - a single float per sequence
1107
+ // extract the rerank score - n_cls_out floats per sequence
1120
1108
  auto & embd_seq_out = embd_seq;
1121
1109
 
1122
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1123
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
1124
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1125
- continue;
1126
- }
1127
- embd_seq_out[seq_id].resize(1);
1128
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
1110
+ const uint32_t n_cls_out = hparams.n_cls_out;
1111
+
1112
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
1113
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
1114
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
1115
+
1116
+ embd_seq_out[seq_id].resize(n_cls_out);
1117
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
1129
1118
  }
1130
1119
  } break;
1131
1120
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1136,7 +1125,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1136
1125
  }
1137
1126
 
1138
1127
  n_outputs_prev += n_outputs;
1139
- } while (mstate->next());
1128
+ } while (mctx->next());
1140
1129
 
1141
1130
  // set to total number of outputs in the batch, for use in llama_get_logits_ith
1142
1131
  n_outputs = n_outputs_all;
@@ -1145,7 +1134,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
1145
1134
  if (n_outputs > 0) {
1146
1135
  bool sorted_output = true;
1147
1136
 
1148
- auto & out_ids = mstate->out_ids();
1137
+ auto & out_ids = balloc->get_out_ids();
1149
1138
 
1150
1139
  GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
1151
1140
 
@@ -1302,7 +1291,7 @@ ggml_cgraph * llama_context::graph_init() {
1302
1291
  return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
1303
1292
  }
1304
1293
 
1305
- ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
1294
+ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
1306
1295
  LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
1307
1296
 
1308
1297
  if (n_tokens % n_seqs != 0) {
@@ -1318,11 +1307,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1318
1307
 
1319
1308
  this->n_outputs = n_outputs;
1320
1309
 
1321
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
1322
- llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
1310
+ llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
1311
+ llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
1323
1312
 
1324
1313
  auto * gf = graph_init();
1325
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
1314
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
1326
1315
 
1327
1316
  this->n_outputs = save_n_outputs;
1328
1317
 
@@ -1343,11 +1332,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
1343
1332
  }
1344
1333
 
1345
1334
  llm_graph_result_ptr llama_context::graph_build(
1346
- ggml_context * ctx,
1347
- ggml_cgraph * gf,
1348
- const llama_ubatch & ubatch,
1349
- llm_graph_type gtype,
1350
- const llama_memory_state_i * mstate) {
1335
+ ggml_context * ctx,
1336
+ ggml_cgraph * gf,
1337
+ const llama_ubatch & ubatch,
1338
+ llm_graph_type gtype,
1339
+ const llama_memory_context_i * mctx) {
1351
1340
  return model.build_graph(
1352
1341
  {
1353
1342
  /*.ctx =*/ ctx,
@@ -1359,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
1359
1348
  /*.backend_cpu =*/ backend_cpu,
1360
1349
  /*.cvec =*/ &cvec,
1361
1350
  /*.loras =*/ &loras,
1362
- /*.mstate =*/ mstate,
1351
+ /*.mctx =*/ mctx,
1363
1352
  /*.cross =*/ &cross,
1364
1353
  /*.n_outputs =*/ n_outputs,
1365
1354
  /*.cb =*/ graph_get_cb(),
@@ -2039,7 +2028,12 @@ void llama_context::opt_epoch_iter(
2039
2028
  batch.logits [pos_batch] = true;
2040
2029
  }
2041
2030
 
2042
- const auto n_tokens_all = batch.n_tokens;
2031
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
2032
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
2033
+ return;
2034
+ }
2035
+
2036
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
2043
2037
 
2044
2038
  n_queued_tokens += n_tokens_all;
2045
2039
 
@@ -2047,8 +2041,8 @@ void llama_context::opt_epoch_iter(
2047
2041
 
2048
2042
  uint32_t n_outputs_all = n_tokens_all;
2049
2043
 
2050
- auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
2051
- if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2044
+ auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
2045
+ if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
2052
2046
  LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
2053
2047
  break;
2054
2048
  }
@@ -2061,17 +2055,17 @@ void llama_context::opt_epoch_iter(
2061
2055
 
2062
2056
  uint32_t pos_batch = 0;
2063
2057
  do {
2064
- const auto & ubatch = mstate->get_ubatch();
2058
+ const auto & ubatch = mctx->get_ubatch();
2065
2059
 
2066
2060
  n_outputs = ubatch.n_tokens;
2067
2061
 
2068
- if (!mstate->apply()) {
2069
- LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
2062
+ if (!mctx->apply()) {
2063
+ LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
2070
2064
  break;
2071
2065
  }
2072
2066
 
2073
2067
  auto * gf = graph_init();
2074
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
2068
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
2075
2069
 
2076
2070
  struct ggml_context * ctx_compute_opt;
2077
2071
  {
@@ -2106,7 +2100,7 @@ void llama_context::opt_epoch_iter(
2106
2100
  ggml_free(ctx_compute_opt);
2107
2101
 
2108
2102
  pos_batch += ubatch.n_tokens;
2109
- } while (mstate->next());
2103
+ } while (mctx->next());
2110
2104
  }
2111
2105
  }
2112
2106
 
@@ -18,7 +18,7 @@ class llama_io_read_i;
18
18
  class llama_io_write_i;
19
19
 
20
20
  struct llama_memory_i;
21
- struct llama_memory_state_i;
21
+ struct llama_memory_context_i;
22
22
 
23
23
  struct llama_context {
24
24
  // init scheduler and compute buffers, reserve worst-case graphs
@@ -93,14 +93,14 @@ struct llama_context {
93
93
  int32_t il_end);
94
94
 
95
95
  // process a single ubatch with a specific graph type
96
- // if memory_state is provided, it will be applied first to the context's memory
96
+ // if memory_context is provided, it will be applied first to the context's memory
97
97
  // ret contains the status of the graph computation
98
98
  // returns nullptr only if ret != GGML_STATUS_SUCCESS
99
99
  llm_graph_result_ptr process_ubatch(
100
- const llama_ubatch & ubatch,
101
- llm_graph_type gtype,
102
- llama_memory_state_i * mstate,
103
- ggml_status & ret);
100
+ const llama_ubatch & ubatch,
101
+ llm_graph_type gtype,
102
+ llama_memory_context_i * mctx,
103
+ ggml_status & ret);
104
104
 
105
105
  int encode(const llama_batch & batch_inp);
106
106
  int decode(const llama_batch & batch_inp);
@@ -197,15 +197,15 @@ public:
197
197
  ggml_status graph_compute(ggml_cgraph * gf, bool batched);
198
198
 
199
199
  // reserve a graph with a dummy ubatch of the specified size
200
- ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
200
+ ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
201
201
 
202
202
  private:
203
203
  llm_graph_result_ptr graph_build(
204
- ggml_context * ctx,
205
- ggml_cgraph * gf,
206
- const llama_ubatch & ubatch,
207
- llm_graph_type gtype,
208
- const llama_memory_state_i * mstate);
204
+ ggml_context * ctx,
205
+ ggml_cgraph * gf,
206
+ const llama_ubatch & ubatch,
207
+ llm_graph_type gtype,
208
+ const llama_memory_context_i * mctx);
209
209
 
210
210
  llm_graph_cb graph_get_cb() const;
211
211
 
@@ -247,7 +247,7 @@ private:
247
247
  std::map<llama_seq_id, std::vector<float>> embd_seq;
248
248
 
249
249
  // reuse the batch_allocr to avoid unnecessary memory allocations
250
- std::unique_ptr<llama_batch_allocr> batch_allocr;
250
+ std::unique_ptr<llama_batch_allocr> balloc;
251
251
 
252
252
  uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
253
253