@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -49,14 +49,14 @@ public:
49
49
  // llama_memory_i
50
50
  //
51
51
 
52
- llama_memory_state_ptr init_batch(
53
- const llama_batch & batch,
52
+ llama_memory_context_ptr init_batch(
53
+ llama_batch_allocr & balloc,
54
54
  uint32_t n_ubatch,
55
- bool embd_pooled) override;
55
+ bool embd_all) override;
56
56
 
57
- llama_memory_state_ptr init_full() override;
57
+ llama_memory_context_ptr init_full() override;
58
58
 
59
- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
59
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
60
60
 
61
61
  bool get_can_shift() const override;
62
62
 
@@ -90,54 +90,51 @@ private:
90
90
  const std::unique_ptr<llama_memory_recurrent> mem_recr;
91
91
  };
92
92
 
93
- class llama_memory_hybrid_state : public llama_memory_state_i {
93
+ class llama_memory_hybrid_context : public llama_memory_context_i {
94
94
  public:
95
+ using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
96
+
95
97
  // init failure
96
- explicit llama_memory_hybrid_state(llama_memory_status status);
98
+ explicit llama_memory_hybrid_context(llama_memory_status status);
97
99
 
98
100
  // init full
99
- explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
101
+ explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
100
102
 
101
103
  // init update
102
- explicit llama_memory_hybrid_state(
104
+ explicit llama_memory_hybrid_context(
103
105
  llama_memory_hybrid * mem,
104
106
  llama_context * lctx,
105
107
  bool optimize);
106
108
 
107
109
  // init success
108
- llama_memory_hybrid_state(
110
+ llama_memory_hybrid_context(
109
111
  llama_memory_hybrid * mem,
110
- llama_sbatch sbatch,
111
- std::vector<uint32_t> heads_attn,
112
+ slot_info_vec_t sinfos_attn,
112
113
  std::vector<llama_ubatch> ubatches);
113
114
 
114
- ~llama_memory_hybrid_state() = default;
115
+ ~llama_memory_hybrid_context() = default;
115
116
 
116
117
  bool next() override;
117
118
  bool apply() override;
118
119
 
119
- std::vector<int64_t> & out_ids() override;
120
-
121
120
  llama_memory_status get_status() const override;
122
121
  const llama_ubatch & get_ubatch() const override;
123
122
 
124
123
  //
125
- // llama_memory_hybrid_state
124
+ // llama_memory_hybrid_context
126
125
  //
127
126
 
128
- const llama_kv_cache_unified_state * get_state_attn() const;
129
- const llama_memory_recurrent_state * get_state_recr() const;
127
+ const llama_kv_cache_unified_context * get_attn() const;
128
+ const llama_memory_recurrent_context * get_recr() const;
130
129
 
131
130
  private:
132
- llama_sbatch sbatch;
133
-
134
131
  // the index of the next ubatch to process
135
132
  size_t i_next = 0;
136
133
 
137
134
  std::vector<llama_ubatch> ubatches;
138
135
 
139
- const llama_memory_state_ptr state_attn;
140
- const llama_memory_state_ptr state_recr;
136
+ const llama_memory_context_ptr ctx_attn;
137
+ const llama_memory_context_ptr ctx_recr;
141
138
 
142
139
  const llama_memory_status status;
143
140
  };
@@ -25,9 +25,6 @@ llama_memory_recurrent::llama_memory_recurrent(
25
25
  uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
26
26
  const int32_t n_layer = hparams.n_layer;
27
27
 
28
- LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
29
- __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
30
-
31
28
  head = 0;
32
29
  size = mem_size;
33
30
  used = 0;
@@ -84,7 +81,7 @@ llama_memory_recurrent::llama_memory_recurrent(
84
81
 
85
82
  ggml_context * ctx = ctx_for_buft(buft);
86
83
  if (!ctx) {
87
- throw std::runtime_error("failed to create ggml context for kv cache");
84
+ throw std::runtime_error("failed to create ggml context for rs cache");
88
85
  }
89
86
 
90
87
  ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
@@ -102,10 +99,10 @@ llama_memory_recurrent::llama_memory_recurrent(
102
99
 
103
100
  ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
104
101
  if (!buf) {
105
- throw std::runtime_error("failed to allocate buffer for kv cache");
102
+ throw std::runtime_error("failed to allocate buffer for rs cache");
106
103
  }
107
104
  ggml_backend_buffer_clear(buf, 0);
108
- LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
105
+ LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
109
106
  bufs.emplace_back(buf);
110
107
  }
111
108
 
@@ -113,8 +110,8 @@ llama_memory_recurrent::llama_memory_recurrent(
113
110
  const size_t memory_size_r = size_r_bytes();
114
111
  const size_t memory_size_s = size_s_bytes();
115
112
 
116
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
117
- (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
113
+ LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
114
+ (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
118
115
  ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
119
116
  ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
120
117
  }
@@ -362,40 +359,52 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
362
359
  return result;
363
360
  }
364
361
 
365
- llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
366
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
362
+ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
363
+ do {
364
+ balloc.split_reset();
367
365
 
368
- std::vector<llama_ubatch> ubatches;
366
+ std::vector<llama_ubatch> ubatches;
367
+ while (true) {
368
+ llama_ubatch ubatch;
369
369
 
370
- while (sbatch.n_tokens > 0) {
371
- llama_ubatch ubatch;
370
+ if (embd_all) {
371
+ // if all tokens are output, split by sequence
372
+ ubatch = balloc.split_seq(n_ubatch);
373
+ } else {
374
+ ubatch = balloc.split_equal(n_ubatch, false);
375
+ }
372
376
 
373
- if (embd_all) {
374
- // if all tokens are output, split by sequence
375
- ubatch = sbatch.split_seq(n_ubatch);
376
- } else {
377
- ubatch = sbatch.split_equal(n_ubatch);
377
+ if (ubatch.n_tokens == 0) {
378
+ break;
379
+ }
380
+
381
+ ubatches.push_back(std::move(ubatch)); // NOLINT
378
382
  }
379
383
 
380
- ubatches.push_back(ubatch);
381
- }
384
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
385
+ // failed to find a suitable split
386
+ break;
387
+ }
382
388
 
383
- if (!prepare(ubatches)) {
384
- return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
385
- }
389
+ if (!prepare(ubatches)) {
390
+ break;
391
+ }
386
392
 
387
- return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
393
+ return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
394
+ } while (false);
395
+
396
+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
388
397
  }
389
398
 
390
- llama_memory_state_ptr llama_memory_recurrent::init_full() {
391
- return std::make_unique<llama_memory_recurrent_state>(this);
399
+ llama_memory_context_ptr llama_memory_recurrent::init_full() {
400
+ return std::make_unique<llama_memory_recurrent_context>(this);
392
401
  }
393
402
 
394
- llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
403
+ llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
395
404
  GGML_UNUSED(lctx);
396
405
  GGML_UNUSED(optimize);
397
406
 
398
- return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
407
+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
399
408
  }
400
409
 
401
410
  bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -423,9 +432,8 @@ bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches)
423
432
  }
424
433
 
425
434
  bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
426
- const uint32_t n_seqs = ubatch.n_seqs;
427
-
428
435
  const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
436
+ const uint32_t n_seqs = ubatch.n_seqs;
429
437
 
430
438
  // if we have enough unused cells before the current head ->
431
439
  // better to start searching from the beginning of the cache, hoping to fill it
@@ -445,9 +453,11 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
445
453
 
446
454
  // everything should fit if all seq_ids are smaller than the max
447
455
  for (uint32_t s = 0; s < n_seqs; ++s) {
448
- const uint32_t n_seq_id = ubatch.n_seq_id[s];
456
+ const uint32_t i = s*n_seq_tokens; // first token of sequence set s
457
+ const uint32_t n_seq_id = ubatch.n_seq_id[i];
458
+
449
459
  for (uint32_t j = 0; j < n_seq_id; ++j) {
450
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
460
+ const llama_seq_id seq_id = ubatch.seq_id[i][j];
451
461
 
452
462
  if (seq_id < 0 || (uint32_t) seq_id >= size) {
453
463
  // too big seq_id
@@ -506,7 +516,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
506
516
 
507
517
  // find usable cell range
508
518
  for (uint32_t s = 0; s < n_seqs; ++s) {
509
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
519
+ const uint32_t i = s*n_seq_tokens;
520
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
510
521
  auto & seq_meta = cells[seq_id];
511
522
  bool has_cell = false;
512
523
  if (seq_meta.tail >= 0) {
@@ -530,7 +541,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
530
541
  seq_meta.tail = next_empty_cell;
531
542
  // find next empty cell
532
543
  if (s + 1 < n_seqs) {
533
- for (uint32_t i = 0; i < size; ++i) {
544
+ for (uint32_t j = 0; j < size; ++j) {
534
545
  next_empty_cell += 1;
535
546
  if (next_empty_cell >= size) { next_empty_cell -= size; }
536
547
  auto & cell = cells[next_empty_cell];
@@ -544,8 +555,9 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
544
555
 
545
556
  // gather and re-order
546
557
  for (uint32_t s = 0; s < n_seqs; ++s) {
558
+ const uint32_t i = s*n_seq_tokens;
547
559
  const int32_t dst_id = s + min;
548
- const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
560
+ const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
549
561
  if (dst_id != src_id) {
550
562
  auto & dst_cell = cells[dst_id];
551
563
  auto & src_cell = cells[src_id];
@@ -555,8 +567,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
555
567
  std::swap(dst_cell.seq_id, src_cell.seq_id);
556
568
 
557
569
  // swap tails
558
- for (uint32_t i = 0; i < size; ++i) {
559
- int32_t & tail = cells[i].tail;
570
+ for (uint32_t j = 0; j < size; ++j) {
571
+ int32_t & tail = cells[j].tail;
560
572
  if (tail == src_id) {
561
573
  tail = dst_id;
562
574
  } else if (tail == dst_id) {
@@ -568,7 +580,8 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
568
580
 
569
581
  // update the pos of the used seqs
570
582
  for (uint32_t s = 0; s < n_seqs; ++s) {
571
- const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
583
+ const uint32_t i = s*n_seq_tokens;
584
+ const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
572
585
  const int32_t cell_id = s + min;
573
586
  auto & cell = cells[cell_id];
574
587
 
@@ -576,12 +589,12 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
576
589
  // What should happen when the pos backtracks or skips a value?
577
590
  // Clearing the state mid-batch would require special-casing which isn't done.
578
591
  LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
579
- __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
592
+ __func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
580
593
  }
581
594
  cell.pos = last_pos;
582
595
  cell.seq_id.clear();
583
- for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
584
- const llama_seq_id seq_id = ubatch.seq_id[s][j];
596
+ for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
597
+ const llama_seq_id seq_id = ubatch.seq_id[i][j];
585
598
  cell.seq_id.insert(seq_id);
586
599
  cells[seq_id].tail = cell_id;
587
600
  }
@@ -827,12 +840,9 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
827
840
 
828
841
  seq_rm(dest_seq_id, -1, -1);
829
842
 
830
- llama_sbatch sbatch;
831
- llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
843
+ llama_batch_allocr balloc(hparams.n_pos_per_embd());
832
844
 
833
- batch.n_tokens = cell_count;
834
- batch.n_seq_tokens = cell_count;
835
- batch.n_seqs = 1;
845
+ llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
836
846
 
837
847
  for (uint32_t i = 0; i < cell_count; ++i) {
838
848
  llama_pos pos;
@@ -846,12 +856,12 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
846
856
  return false;
847
857
  }
848
858
 
849
- batch.pos[i] = pos;
859
+ ubatch.pos[i] = pos;
850
860
  }
851
- batch.n_seq_id[0] = 1;
852
- batch.seq_id[0] = &dest_seq_id;
861
+ ubatch.n_seq_id[0] = 1;
862
+ ubatch.seq_id[0] = &dest_seq_id;
853
863
 
854
- if (!find_slot(batch)) {
864
+ if (!find_slot(ubatch)) {
855
865
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
856
866
  return false;
857
867
  }
@@ -859,8 +869,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
859
869
  // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
860
870
  // Assume that this is one contiguous block of cells
861
871
  GGML_ASSERT(head + cell_count <= size);
862
- GGML_ASSERT(cells[head].pos == batch.pos[0]);
863
- GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
872
+ GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
873
+ GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
864
874
  GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
865
875
  GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
866
876
  } else {
@@ -1037,23 +1047,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
1037
1047
  }
1038
1048
 
1039
1049
  //
1040
- // llama_memory_recurrent_state
1050
+ // llama_memory_recurrent_context
1041
1051
  //
1042
1052
 
1043
- llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
1053
+ llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
1044
1054
 
1045
- llama_memory_recurrent_state::llama_memory_recurrent_state(
1055
+ llama_memory_recurrent_context::llama_memory_recurrent_context(
1046
1056
  llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
1047
1057
  }
1048
1058
 
1049
- llama_memory_recurrent_state::llama_memory_recurrent_state(
1059
+ llama_memory_recurrent_context::llama_memory_recurrent_context(
1050
1060
  llama_memory_recurrent * mem,
1051
- llama_sbatch sbatch,
1052
- std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
1061
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
1053
1062
 
1054
- llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
1063
+ llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
1055
1064
 
1056
- bool llama_memory_recurrent_state::next() {
1065
+ bool llama_memory_recurrent_context::next() {
1057
1066
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1058
1067
 
1059
1068
  if (++i_next >= ubatches.size()) {
@@ -1063,54 +1072,56 @@ bool llama_memory_recurrent_state::next() {
1063
1072
  return true;
1064
1073
  }
1065
1074
 
1066
- bool llama_memory_recurrent_state::apply() {
1067
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1075
+ bool llama_memory_recurrent_context::apply() {
1076
+ assert(!llama_memory_status_is_fail(status));
1068
1077
 
1069
- mem->find_slot(ubatches[i_next]);
1078
+ // no ubatches -> this is an update
1079
+ if (ubatches.empty()) {
1080
+ // recurrent cache never performs updates
1081
+ assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE);
1070
1082
 
1071
- return true;
1072
- }
1083
+ return true;
1084
+ }
1073
1085
 
1074
- std::vector<int64_t> & llama_memory_recurrent_state::out_ids() {
1075
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1086
+ mem->find_slot(ubatches[i_next]);
1076
1087
 
1077
- return sbatch.out_ids;
1088
+ return true;
1078
1089
  }
1079
1090
 
1080
- llama_memory_status llama_memory_recurrent_state::get_status() const {
1091
+ llama_memory_status llama_memory_recurrent_context::get_status() const {
1081
1092
  return status;
1082
1093
  }
1083
1094
 
1084
- const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
1095
+ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
1085
1096
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1086
1097
 
1087
1098
  return ubatches[i_next];
1088
1099
  }
1089
1100
 
1090
- uint32_t llama_memory_recurrent_state::get_n_rs() const {
1101
+ uint32_t llama_memory_recurrent_context::get_n_rs() const {
1091
1102
  return is_full ? mem->size : mem->n;
1092
1103
  }
1093
1104
 
1094
- uint32_t llama_memory_recurrent_state::get_head() const {
1105
+ uint32_t llama_memory_recurrent_context::get_head() const {
1095
1106
  return is_full ? 0 : mem->head;
1096
1107
  }
1097
1108
 
1098
- int32_t llama_memory_recurrent_state::get_rs_z() const {
1109
+ int32_t llama_memory_recurrent_context::get_rs_z() const {
1099
1110
  return is_full ? 0 : mem->rs_z;
1100
1111
  }
1101
1112
 
1102
- uint32_t llama_memory_recurrent_state::get_size() const {
1113
+ uint32_t llama_memory_recurrent_context::get_size() const {
1103
1114
  return mem->size;
1104
1115
  }
1105
1116
 
1106
- ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
1117
+ ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
1107
1118
  return mem->r_l[il];
1108
1119
  }
1109
1120
 
1110
- ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
1121
+ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
1111
1122
  return mem->s_l[il];
1112
1123
  }
1113
1124
 
1114
- int32_t llama_memory_recurrent_state::s_copy(int i) const {
1125
+ int32_t llama_memory_recurrent_context::s_copy(int i) const {
1115
1126
  return mem->cells[i + mem->head].src0;
1116
1127
  }
@@ -11,8 +11,8 @@
11
11
  // llama_memory_recurrent
12
12
  //
13
13
 
14
- // TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
15
- // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
14
+ // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
15
+ // see the implementation of llama_kv_cache_unified_context_i for an example how to do it
16
16
  class llama_memory_recurrent : public llama_memory_i {
17
17
  public:
18
18
 
@@ -34,14 +34,14 @@ public:
34
34
  // llama_memory_i
35
35
  //
36
36
 
37
- llama_memory_state_ptr init_batch(
38
- const llama_batch & batch,
37
+ llama_memory_context_ptr init_batch(
38
+ llama_batch_allocr & balloc,
39
39
  uint32_t n_ubatch,
40
40
  bool embd_all) override;
41
41
 
42
- llama_memory_state_ptr init_full() override;
42
+ llama_memory_context_ptr init_full() override;
43
43
 
44
- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
44
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
45
45
 
46
46
  void clear(bool data) override;
47
47
 
@@ -125,37 +125,34 @@ private:
125
125
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
126
126
  };
127
127
 
128
- class llama_memory_recurrent_state : public llama_memory_state_i {
128
+ class llama_memory_recurrent_context : public llama_memory_context_i {
129
129
  public:
130
130
  // used for errors
131
- llama_memory_recurrent_state(llama_memory_status status);
131
+ llama_memory_recurrent_context(llama_memory_status status);
132
132
 
133
- // used to create a full-cache state
134
- llama_memory_recurrent_state(
133
+ // used to create a full-cache or update context
134
+ llama_memory_recurrent_context(
135
135
  llama_memory_recurrent * mem);
136
136
 
137
- // used to create a state from a batch
138
- llama_memory_recurrent_state(
137
+ // used to create a batch processing context from a batch
138
+ llama_memory_recurrent_context(
139
139
  llama_memory_recurrent * mem,
140
- llama_sbatch sbatch,
141
140
  std::vector<llama_ubatch> ubatches);
142
141
 
143
- virtual ~llama_memory_recurrent_state();
142
+ virtual ~llama_memory_recurrent_context();
144
143
 
145
144
  //
146
- // llama_memory_state_i
145
+ // llama_memory_context_i
147
146
  //
148
147
 
149
148
  bool next() override;
150
149
  bool apply() override;
151
150
 
152
- std::vector<int64_t> & out_ids() override;
153
-
154
151
  llama_memory_status get_status() const override;
155
152
  const llama_ubatch & get_ubatch() const override;
156
153
 
157
154
  //
158
- // llama_memory_recurrent_state specific API
155
+ // llama_memory_recurrent_context specific API
159
156
  //
160
157
 
161
158
  uint32_t get_n_rs() const;
@@ -173,8 +170,6 @@ private:
173
170
 
174
171
  llama_memory_recurrent * mem;
175
172
 
176
- llama_sbatch sbatch;
177
-
178
173
  size_t i_next = 0;
179
174
 
180
175
  std::vector<llama_ubatch> ubatches;
@@ -40,3 +40,20 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me
40
40
  // if either status has an update, then the combined status has an update
41
41
  return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
42
42
  }
43
+
44
+ bool llama_memory_status_is_fail(llama_memory_status status) {
45
+ switch (status) {
46
+ case LLAMA_MEMORY_STATUS_SUCCESS:
47
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
48
+ {
49
+ return false;
50
+ }
51
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
52
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
53
+ {
54
+ return true;
55
+ }
56
+ }
57
+
58
+ return false;
59
+ }
@@ -3,10 +3,11 @@
3
3
  #include "llama.h"
4
4
 
5
5
  #include <memory>
6
- #include <vector>
7
6
 
8
7
  struct llama_ubatch;
9
8
 
9
+ class llama_batch_allocr;
10
+
10
11
  class llama_io_write_i;
11
12
  class llama_io_read_i;
12
13
 
@@ -26,23 +27,24 @@ enum llama_memory_status {
26
27
  LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
27
28
  };
28
29
 
29
- // helper function for combining the status of two memory states
30
+ // helper function for combining the status of two memory contexts
30
31
  // useful for implementing hybrid memory types (e.g. iSWA)
31
32
  llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
32
33
 
33
- // the interface for managing the memory state during batch processing
34
+ // helper function for checking if a memory status indicates a failure
35
+ bool llama_memory_status_is_fail(llama_memory_status status);
36
+
37
+ // the interface for managing the memory context during batch processing
34
38
  // this interface is implemented per memory type. see:
35
- // - llama_kv_cache_unified_state
36
- // - llama_kv_cache_unified_iswa_state
39
+ // - llama_kv_cache_unified_context
40
+ // - llama_kv_cache_unified_iswa_context
37
41
  // ...
38
42
  //
39
- // the only method that can mutate the memory and the memory state is llama_memory_i::apply()
40
- //
41
- // TODO: rename to llama_memory_context_i ?
42
- struct llama_memory_state_i {
43
- virtual ~llama_memory_state_i() = default;
43
+ // the only method that should mutate the memory and the memory context is llama_memory_i::apply()
44
+ struct llama_memory_context_i {
45
+ virtual ~llama_memory_context_i() = default;
44
46
 
45
- // consume the current ubatch from the state and proceed to the next one
47
+ // consume the current ubatch from the context and proceed to the next one
46
48
  // return false if we are done
47
49
  virtual bool next() = 0;
48
50
 
@@ -50,17 +52,14 @@ struct llama_memory_state_i {
50
52
  // return false on failure
51
53
  virtual bool apply() = 0;
52
54
 
53
- // TODO: this might get reworked in the future when refactoring llama_batch
54
- virtual std::vector<int64_t> & out_ids() = 0;
55
-
56
55
  // get the current ubatch
57
56
  virtual const llama_ubatch & get_ubatch() const = 0;
58
57
 
59
- // get the status of the memory state - used for error handling and checking if any updates would be applied
58
+ // get the status of the memory context - used for error handling and checking if any updates would be applied
60
59
  virtual llama_memory_status get_status() const = 0;
61
60
  };
62
61
 
63
- using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
62
+ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
64
63
 
65
64
  // general concept of LLM memory
66
65
  // the KV cache is a type of LLM memory, but there can be other types
@@ -68,19 +67,19 @@ struct llama_memory_i {
68
67
  virtual ~llama_memory_i() = default;
69
68
 
70
69
  // split the input batch into a set of ubatches and verify that they can fit into the cache
71
- // return a state object containing the ubatches and KV cache state required to process them
72
- // check the llama_memory_state_i::get_status() for the result
73
- virtual llama_memory_state_ptr init_batch(
74
- const llama_batch & batch,
70
+ // return a context object containing the ubatches and memory state required to process them
71
+ // check the llama_memory_context_i::get_status() for the result
72
+ virtual llama_memory_context_ptr init_batch(
73
+ llama_batch_allocr & balloc,
75
74
  uint32_t n_ubatch,
76
75
  bool embd_all) = 0;
77
76
 
78
77
  // simulate full cache, used for allocating worst-case compute buffers
79
- virtual llama_memory_state_ptr init_full() = 0;
78
+ virtual llama_memory_context_ptr init_full() = 0;
80
79
 
81
80
  // prepare for any pending memory updates, such as shifts, defrags, etc.
82
81
  // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
83
- virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
82
+ virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
84
83
 
85
84
  // getters
86
85
  virtual bool get_can_shift() const = 0;
@@ -228,6 +228,7 @@ void llama_model_saver::add_kv_from_model() {
228
228
  // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
229
229
  add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
230
230
  add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
231
+ add_kv(LLM_KV_TOKENIZER_ADD_SEP, vocab.get_add_sep());
231
232
  add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
232
233
  add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
233
234
  add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());