@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -87,41 +87,33 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
87
87
 
88
88
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
89
89
  if (pos_bucket) {
90
- kv_state->set_input_pos_bucket(pos_bucket, ubatch);
90
+ mctx->set_input_pos_bucket(pos_bucket, ubatch);
91
91
  }
92
92
  }
93
93
 
94
94
  void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
95
- if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
96
- //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
95
+ GGML_ASSERT(out_ids);
97
96
 
98
- if (!out_ids) {
99
- LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
100
- } else {
101
- const int64_t n_tokens = ubatch->n_tokens;
97
+ const int64_t n_tokens = ubatch->n_tokens;
102
98
 
103
- GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
104
- int32_t * data = (int32_t *) out_ids->data;
99
+ GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
100
+ int32_t * data = (int32_t *) out_ids->data;
105
101
 
106
- if (n_outputs == n_tokens) {
107
- for (int i = 0; i < n_tokens; ++i) {
108
- data[i] = i;
109
- }
110
- } else if (ubatch->output) {
111
- int32_t n_outputs = 0;
112
- for (int i = 0; i < n_tokens; ++i) {
113
- if (ubatch->output[i]) {
114
- data[n_outputs++] = i;
115
- }
116
- }
117
- // the graph needs to have been passed the correct number of outputs
118
- GGML_ASSERT(n_outputs == n_outputs);
119
- } else if (n_outputs == 1) {
120
- // only keep last output
121
- data[0] = n_tokens - 1;
122
- } else {
123
- GGML_ASSERT(n_outputs == 0);
124
- }
102
+ if (n_outputs == n_tokens) {
103
+ for (int i = 0; i < n_tokens; ++i) {
104
+ data[i] = i;
105
+ }
106
+
107
+ return;
108
+ }
109
+
110
+ GGML_ASSERT(ubatch->output);
111
+
112
+ int n_outputs = 0;
113
+
114
+ for (int i = 0; i < n_tokens; ++i) {
115
+ if (ubatch->output[i]) {
116
+ data[n_outputs++] = i;
125
117
  }
126
118
  }
127
119
  }
@@ -130,110 +122,97 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
130
122
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
131
123
  const int64_t n_tokens = ubatch->n_tokens;
132
124
  const int64_t n_seq_tokens = ubatch->n_seq_tokens;
133
- const int64_t n_seqs = ubatch->n_seqs;
125
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
134
126
 
135
127
  GGML_ASSERT(mean);
136
128
  GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
137
129
 
138
130
  float * data = (float *) mean->data;
139
- memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
140
-
141
- std::vector<uint64_t> sum(n_tokens, 0);
131
+ memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
142
132
 
143
- // TODO: fix indexing [UBATCH_IDX]
144
- for (int s = 0; s < n_seqs; ++s) {
145
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
133
+ std::vector<uint64_t> sums(n_seqs_unq, 0);
134
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
135
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
136
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
137
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
146
138
 
147
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
148
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
149
-
150
- sum[seq_id] += ubatch->n_seq_tokens;
139
+ sums[seq_idx] += ubatch->n_seq_tokens;
140
+ }
151
141
  }
152
142
 
153
- std::vector<float> div(n_tokens, 0.0f);
154
- for (int i = 0; i < n_tokens; ++i) {
155
- const uint64_t s = sum[i];
156
- if (s > 0) {
157
- div[i] = 1.0f/float(s);
143
+ std::vector<float> div(n_seqs_unq, 0.0f);
144
+ for (int s = 0; s < n_seqs_unq; ++s) {
145
+ const uint64_t sum = sums[s];
146
+ if (sum > 0) {
147
+ div[s] = 1.0f/float(sum);
158
148
  }
159
149
  }
160
150
 
161
- // TODO: fix indexing [UBATCH_IDX]
162
- for (int s = 0; s < n_seqs; ++s) {
163
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
151
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
152
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
153
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
154
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
164
155
 
165
- for (int i = 0; i < n_seq_tokens; ++i) {
166
- data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
156
+ for (int j = 0; j < n_seq_tokens; ++j) {
157
+ data[seq_idx*n_tokens + i + j] = div[seq_idx];
158
+ }
167
159
  }
168
160
  }
169
161
  }
170
162
  }
171
163
 
172
164
  void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
173
- if (cparams.embeddings && (
174
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
175
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
176
- const int64_t n_tokens = ubatch->n_tokens;
177
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
178
- const int64_t n_seqs = ubatch->n_seqs;
165
+ const int64_t n_tokens = ubatch->n_tokens;
166
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
167
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
179
168
 
169
+ if (cparams.embeddings && (
170
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
171
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
172
+ )) {
180
173
  GGML_ASSERT(cls);
181
174
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
182
175
 
183
176
  uint32_t * data = (uint32_t *) cls->data;
184
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
185
-
186
- // TODO: fix indexing [UBATCH_IDX]
187
- for (int s = 0; s < n_seqs; ++s) {
188
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
177
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
189
178
 
190
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
191
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
179
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
180
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
181
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
182
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
192
183
 
193
- for (int i = 0; i < n_seq_tokens; ++i) {
194
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
195
-
196
- if (pos == 0) {
197
- data[seq_id] = s*n_seq_tokens + i;
198
- }
184
+ data[seq_idx] = i;
199
185
  }
200
186
  }
201
187
  }
202
188
 
203
189
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
204
- const int64_t n_tokens = ubatch->n_tokens;
205
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
206
- const int64_t n_seqs = ubatch->n_seqs;
207
-
208
190
  GGML_ASSERT(cls);
209
191
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
210
192
 
211
193
  uint32_t * data = (uint32_t *) cls->data;
212
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
213
-
214
- std::vector<int> last_pos(n_tokens, -1);
215
- std::vector<int> last_row(n_tokens, -1);
194
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
216
195
 
217
- // TODO: fix indexing [UBATCH_IDX]
218
- for (int s = 0; s < n_seqs; ++s) {
219
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
196
+ std::vector<int> last_pos(n_seqs_unq, -1);
197
+ std::vector<int> last_row(n_seqs_unq, -1);
220
198
 
221
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
222
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
199
+ for (int i = 0; i < n_tokens; ++i) {
200
+ const llama_pos pos = ubatch->pos[i];
223
201
 
224
- for (int i = 0; i < n_seq_tokens; ++i) {
225
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
202
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
203
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
204
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
226
205
 
227
- if (pos >= last_pos[seq_id]) {
228
- last_pos[seq_id] = pos;
229
- last_row[seq_id] = s*n_seq_tokens + i;
206
+ if (pos >= last_pos[seq_idx]) {
207
+ last_pos[seq_idx] = pos;
208
+ last_row[seq_idx] = i;
230
209
  }
231
210
  }
232
211
  }
233
212
 
234
- for (int i = 0; i < n_tokens; ++i) {
235
- if (last_row[i] >= 0) {
236
- data[i] = last_row[i];
213
+ for (int s = 0; s < n_seqs_unq; ++s) {
214
+ if (last_row[s] >= 0) {
215
+ data[s] = last_row[s];
237
216
  }
238
217
  }
239
218
  }
@@ -242,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
242
221
  void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
243
222
  GGML_UNUSED(ubatch);
244
223
 
245
- const int64_t n_rs = mem_state->get_n_rs();
224
+ const int64_t n_rs = mctx->get_n_rs();
246
225
 
247
226
  if (s_copy) {
248
227
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -250,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
250
229
 
251
230
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
252
231
  for (uint32_t i = 0; i < n_rs; ++i) {
253
- data[i] = mem_state->s_copy(i);
232
+ data[i] = mctx->s_copy(i);
254
233
  }
255
234
  }
256
235
  }
@@ -266,160 +245,99 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
266
245
  }
267
246
 
268
247
  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
269
- if (kq_mask) {
270
- if (cparams.causal_attn) {
271
- const int64_t n_kv = ubatch->n_tokens;
272
- const int64_t n_tokens = ubatch->n_tokens;
273
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
274
- const int64_t n_seqs = ubatch->n_seqs;
275
-
276
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
277
- float * data = (float *) kq_mask->data;
278
-
279
- for (int h = 0; h < 1; ++h) {
280
- for (int s1 = 0; s1 < n_seqs; ++s1) {
281
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
282
-
283
- for (int j = 0; j < n_seq_tokens; ++j) {
284
- const int32_t tj = s1*n_seq_tokens + j;
285
-
286
- for (int s0 = 0; s0 < n_seqs; ++s0) {
287
- for (int i = 0; i < n_seq_tokens; ++i) {
288
- const int32_t ti = s0*n_seq_tokens + i;
289
- float f = -INFINITY;
290
-
291
- // TODO: fix indexing [UBATCH_IDX]
292
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
293
- if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
294
- if (hparams.use_alibi) {
295
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
296
- } else {
297
- f = 0.0f;
298
- }
299
- break;
300
- }
301
- }
302
-
303
- data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
304
- }
305
- }
306
- }
307
- }
308
- }
309
- } else {
310
- const int64_t n_tokens = ubatch->n_tokens;
311
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
312
- const int64_t n_seqs = ubatch->n_seqs;
313
- const int64_t n_stride = ubatch->n_tokens;
314
-
315
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
316
-
317
- float * data = (float *) kq_mask->data;
318
-
319
- for (int h = 0; h < 1; ++h) {
320
- for (int s1 = 0; s1 < n_seqs; ++s1) {
321
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
322
-
323
- for (int j = 0; j < n_seq_tokens; ++j) {
324
- const int32_t tj = s1*n_seq_tokens + j;
325
-
326
- for (int s0 = 0; s0 < n_seqs; ++s0) {
327
- for (int i = 0; i < n_seq_tokens; ++i) {
328
- const int32_t ti = s0*n_seq_tokens + i;
329
- float f = -INFINITY;
330
-
331
- // TODO: fix indexing [UBATCH_IDX]
332
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
333
- if (ubatch->seq_id[s0][s] == seq_id) {
334
- if (hparams.use_alibi) {
335
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
336
- } else {
337
- f = 0.0f;
338
- }
339
- break;
340
- }
341
- }
342
-
343
- data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
344
- }
345
- }
248
+ const int64_t n_kv = ubatch->n_tokens;
249
+ const int64_t n_tokens = ubatch->n_tokens;
250
+
251
+ GGML_ASSERT(kq_mask);
252
+ GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
253
+
254
+ float * data = (float *) kq_mask->data;
255
+
256
+ for (int h = 0; h < 1; ++h) {
257
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
258
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
346
259
 
347
- for (int i = n_tokens; i < n_stride; ++i) {
348
- data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
260
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
261
+ float f = -INFINITY;
262
+
263
+ for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
264
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
265
+
266
+ // TODO: reimplement this like in llama_kv_cache_unified
267
+ if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
268
+ if (hparams.use_alibi) {
269
+ f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
270
+ } else {
271
+ f = 0.0f;
349
272
  }
273
+ break;
350
274
  }
351
275
  }
276
+
277
+ data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
352
278
  }
353
279
  }
354
280
  }
355
281
  }
356
282
 
357
283
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
358
- if (self_kq_mask) {
359
- kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
360
- }
284
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
285
+ mctx->set_input_v_idxs(self_v_idxs, ubatch);
286
+
287
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
361
288
  }
362
289
 
363
290
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
364
- if (self_kq_mask) {
365
- kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
366
- }
291
+ mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
292
+ mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
367
293
 
368
- if (self_kq_mask_swa) {
369
- kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
370
- }
294
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
295
+
296
+ mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
297
+ mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
298
+
299
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
371
300
  }
372
301
 
373
302
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
374
- if (cross_kq_mask) {
375
- const int64_t n_enc = cross_kq_mask->ne[0];
376
- const int64_t n_tokens = ubatch->n_tokens;
303
+ GGML_ASSERT(cross_kq_mask);
377
304
 
378
- GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
379
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
305
+ const int64_t n_enc = cross_kq_mask->ne[0];
306
+ const int64_t n_tokens = ubatch->n_tokens;
380
307
 
381
- float * data = (float *) cross_kq_mask->data;
308
+ GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
309
+ GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
382
310
 
383
- for (int h = 0; h < 1; ++h) {
384
- for (int j = 0; j < n_tokens; ++j) {
385
- for (int i = 0; i < n_enc; ++i) {
386
- float f = -INFINITY;
387
- // TODO: fix indexing [UBATCH_IDX]
388
- for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
389
- const llama_seq_id seq_id = ubatch->seq_id[j][s];
390
- if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
391
- f = 0.0f;
392
- }
311
+ float * data = (float *) cross_kq_mask->data;
312
+
313
+ for (int h = 0; h < 1; ++h) {
314
+ for (int i = 0; i < n_tokens; ++i) {
315
+ for (int j = 0; j < n_enc; ++j) {
316
+ float f = -INFINITY;
317
+
318
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
319
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
320
+
321
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
322
+ f = 0.0f;
393
323
  }
394
- data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
395
324
  }
325
+
326
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
396
327
  }
328
+ }
397
329
 
398
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
399
- for (int j = 0; j < n_enc; ++j) {
400
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
401
- }
330
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
331
+ for (int j = 0; j < n_enc; ++j) {
332
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
402
333
  }
403
334
  }
404
335
  }
405
336
  }
406
337
 
407
338
  void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
408
- if (self_kq_mask) {
409
- mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
410
- }
411
-
412
- const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
413
-
414
- if (s_copy) {
415
- GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
416
- int32_t * data = (int32_t *) s_copy->data;
417
-
418
- // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
419
- for (uint32_t i = 0; i < n_rs; ++i) {
420
- data[i] = mem_state->get_state_recr()->s_copy(i);
421
- }
422
- }
339
+ inp_attn->set_input(ubatch);
340
+ inp_rs->set_input(ubatch);
423
341
  }
424
342
 
425
343
  //
@@ -461,16 +379,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
461
379
  backend_cpu (params.backend_cpu),
462
380
  cvec (params.cvec),
463
381
  loras (params.loras),
464
- mstate (params.mstate),
382
+ mctx (params.mctx),
465
383
  cross (params.cross),
466
384
  cb_func (params.cb),
467
385
  res (std::make_unique<llm_graph_result>()) {
468
386
  }
469
387
 
470
- int64_t llm_graph_context::n_pos_per_embd() const {
471
- return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
472
- }
473
-
474
388
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
475
389
  if (cb_func) {
476
390
  cb_func(ubatch, cur, name, il);
@@ -630,12 +544,20 @@ ggml_tensor * llm_graph_context::build_ffn(
630
544
 
631
545
  switch (type_op) {
632
546
  case LLM_FFN_SILU:
633
- {
547
+ if (gate && type_gate == LLM_FFN_PAR) {
548
+ cur = ggml_swiglu_split(ctx0, cur, tmp);
549
+ cb(cur, "ffn_swiglu", il);
550
+ type_gate = LLM_FFN_SEQ;
551
+ } else {
634
552
  cur = ggml_silu(ctx0, cur);
635
553
  cb(cur, "ffn_silu", il);
636
554
  } break;
637
555
  case LLM_FFN_GELU:
638
- {
556
+ if (gate && type_gate == LLM_FFN_PAR) {
557
+ cur = ggml_geglu_split(ctx0, cur, tmp);
558
+ cb(cur, "ffn_geglu", il);
559
+ type_gate = LLM_FFN_SEQ;
560
+ } else {
639
561
  cur = ggml_gelu(ctx0, cur);
640
562
  cb(cur, "ffn_gelu", il);
641
563
  if (act_scales != NULL) {
@@ -644,7 +566,11 @@ ggml_tensor * llm_graph_context::build_ffn(
644
566
  }
645
567
  } break;
646
568
  case LLM_FFN_RELU:
647
- {
569
+ if (gate && type_gate == LLM_FFN_PAR) {
570
+ cur = ggml_reglu_split(ctx0, cur, tmp);
571
+ cb(cur, "ffn_reglu", il);
572
+ type_gate = LLM_FFN_SEQ;
573
+ } else {
648
574
  cur = ggml_relu(ctx0, cur);
649
575
  cb(cur, "ffn_relu", il);
650
576
  } break;
@@ -658,32 +584,19 @@ ggml_tensor * llm_graph_context::build_ffn(
658
584
  } break;
659
585
  case LLM_FFN_SWIGLU:
660
586
  {
661
- // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
662
- int64_t split_point = cur->ne[0] / 2;
663
- // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
664
- ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
665
- ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
666
-
667
- x0 = ggml_silu(ctx0, x0);
668
- cb(cur, "ffn_silu", il);
669
-
670
- cur = ggml_mul(ctx0, x0, x1);
671
- cb(cur, "ffn_mul", il);
587
+ cur = ggml_swiglu(ctx0, cur);
588
+ cb(cur, "ffn_swiglu", il);
672
589
  } break;
673
590
  case LLM_FFN_GEGLU:
674
591
  {
675
- // Split into two equal parts
676
- int64_t split_point = cur->ne[0] / 2;
677
- // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
678
- ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
679
- ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
680
-
681
- x0 = ggml_gelu(ctx0, x0);
682
- cb(x0, "ffn_gelu", il);
683
-
684
- cur = ggml_mul(ctx0, x0, x1);
592
+ cur = ggml_geglu(ctx0, cur);
685
593
  cb(cur, "ffn_geglu", il);
686
594
  } break;
595
+ case LLM_FFN_REGLU:
596
+ {
597
+ cur = ggml_reglu(ctx0, cur);
598
+ cb(cur, "ffn_reglu", il);
599
+ } break;
687
600
  }
688
601
 
689
602
  if (gate && type_gate == LLM_FFN_PAR) {
@@ -813,12 +726,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
813
726
 
814
727
  switch (type_op) {
815
728
  case LLM_FFN_SILU:
816
- {
729
+ if (gate_exps) {
730
+ cur = ggml_swiglu_split(ctx0, cur, up);
731
+ cb(cur, "ffn_moe_swiglu", il);
732
+ } else {
817
733
  cur = ggml_silu(ctx0, cur);
818
734
  cb(cur, "ffn_moe_silu", il);
819
735
  } break;
820
736
  case LLM_FFN_GELU:
821
- {
737
+ if (gate_exps) {
738
+ cur = ggml_geglu_split(ctx0, cur, up);
739
+ cb(cur, "ffn_moe_geglu", il);
740
+ } else {
822
741
  cur = ggml_gelu(ctx0, cur);
823
742
  cb(cur, "ffn_moe_gelu", il);
824
743
  } break;
@@ -826,11 +745,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
826
745
  GGML_ABORT("fatal error");
827
746
  }
828
747
 
829
- if (gate_exps) {
830
- cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
831
- cb(cur, "ffn_moe_gate_par", il);
832
- }
833
-
834
748
  experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
835
749
  cb(experts, "ffn_moe_down", il);
836
750
 
@@ -915,11 +829,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
915
829
  }
916
830
 
917
831
  ggml_tensor * llm_graph_context::build_inp_pos() const {
918
- auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
832
+ auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
919
833
 
920
834
  auto & cur = inp->pos;
921
835
 
922
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
836
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
923
837
  ggml_set_input(cur);
924
838
 
925
839
  res->add_input(std::move(inp));
@@ -942,6 +856,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
942
856
  }
943
857
 
944
858
  ggml_tensor * llm_graph_context::build_inp_out_ids() const {
859
+ // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
860
+ // but this would make the graph topology depend on the number of output tokens, which can interfere with
861
+ // features that require constant topology such as pipeline parallelism
862
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
863
+ //if (n_outputs < n_tokens) {
864
+ // return nullptr;
865
+ //}
866
+
945
867
  auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
946
868
 
947
869
  auto & cur = inp->out_ids;
@@ -959,7 +881,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
959
881
 
960
882
  auto & cur = inp->mean;
961
883
 
962
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
884
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
963
885
  ggml_set_input(cur);
964
886
 
965
887
  res->add_input(std::move(inp));
@@ -972,7 +894,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
972
894
 
973
895
  auto & cur = inp->cls;
974
896
 
975
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
897
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
976
898
  ggml_set_input(cur);
977
899
 
978
900
  res->add_input(std::move(inp));
@@ -1018,11 +940,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1018
940
  }
1019
941
 
1020
942
  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1021
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
943
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1022
944
 
1023
- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
945
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
1024
946
 
1025
- const auto n_kv = kv_state->get_n_kv();
947
+ const auto n_kv = mctx_cur->get_n_kv();
1026
948
 
1027
949
  auto & cur = inp->pos_bucket;
1028
950
 
@@ -1049,33 +971,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
1049
971
  return pos_bias;
1050
972
  }
1051
973
 
1052
- llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1053
- const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
1054
-
1055
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
1056
-
1057
- {
1058
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
1059
-
1060
- const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
1061
-
1062
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1063
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1064
- ggml_set_input(inp->self_kq_mask);
1065
-
1066
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1067
- }
1068
-
1069
- {
1070
- const auto n_rs = mem_state->get_state_recr()->get_n_rs();
1071
-
1072
- inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1073
- ggml_set_input(inp->s_copy);
1074
- }
1075
-
1076
- return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1077
- }
1078
-
1079
974
  ggml_tensor * llm_graph_context::build_attn_mha(
1080
975
  ggml_cgraph * gf,
1081
976
  ggml_tensor * q,
@@ -1197,8 +1092,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1197
1092
  auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1198
1093
 
1199
1094
  // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1200
- inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1201
- //cb(inp_kq_mask, "KQ_mask", -1);
1095
+ inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1202
1096
  ggml_set_input(inp->kq_mask);
1203
1097
 
1204
1098
  inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1250,23 +1144,38 @@ ggml_tensor * llm_graph_context::build_attn(
1250
1144
  return cur;
1251
1145
  }
1252
1146
 
1253
- llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1254
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1147
+ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
1148
+ ggml_context * ctx0,
1149
+ const llama_ubatch & ubatch,
1150
+ const llama_hparams & hparams,
1151
+ const llama_cparams & cparams,
1152
+ const llama_kv_cache_unified_context * mctx_cur) {
1255
1153
 
1256
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
1154
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
1257
1155
 
1258
1156
  {
1259
1157
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1260
1158
 
1261
- const auto n_kv = kv_state->get_n_kv();
1159
+ const auto n_kv = mctx_cur->get_n_kv();
1160
+ const auto n_tokens = ubatch.n_tokens;
1262
1161
 
1263
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1264
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1162
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1163
+ inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1164
+
1165
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1265
1166
  ggml_set_input(inp->self_kq_mask);
1266
1167
 
1267
1168
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1268
1169
  }
1269
1170
 
1171
+ return inp;
1172
+ }
1173
+
1174
+ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1175
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1176
+
1177
+ auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1178
+
1270
1179
  return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
1271
1180
  }
1272
1181
 
@@ -1288,19 +1197,22 @@ ggml_tensor * llm_graph_context::build_attn(
1288
1197
  ggml_build_forward_expand(gf, k_cur);
1289
1198
  ggml_build_forward_expand(gf, v_cur);
1290
1199
 
1291
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1200
+ const auto * mctx_cur = inp->mctx;
1292
1201
 
1293
1202
  // store to KV cache
1294
1203
  {
1295
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1296
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1204
+ const auto & k_idxs = inp->get_k_idxs();
1205
+ const auto & v_idxs = inp->get_v_idxs();
1206
+
1207
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1208
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1297
1209
  }
1298
1210
 
1299
1211
  const auto & kq_mask = inp->get_kq_mask();
1300
1212
 
1301
1213
  ggml_tensor * q = q_cur;
1302
- ggml_tensor * k = kv_state->get_k(ctx0, il);
1303
- ggml_tensor * v = kv_state->get_v(ctx0, il);
1214
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1215
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1304
1216
 
1305
1217
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1306
1218
  cb(cur, "kqv_out", il);
@@ -1335,26 +1247,39 @@ ggml_tensor * llm_graph_context::build_attn(
1335
1247
  // these nodes are added to the graph together so that they are not reordered
1336
1248
  // by doing so, the number of splits in the graph is reduced
1337
1249
  ggml_build_forward_expand(gf, q_cur);
1338
- ggml_build_forward_expand(gf, k_cur);
1339
- ggml_build_forward_expand(gf, v_cur);
1340
1250
 
1341
- const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1251
+ if (k_cur) {
1252
+ ggml_build_forward_expand(gf, k_cur);
1253
+ }
1254
+
1255
+ if (v_cur) {
1256
+ ggml_build_forward_expand(gf, v_cur);
1257
+ }
1258
+
1259
+ const auto * mctx_iswa = inp->mctx;
1342
1260
 
1343
1261
  const bool is_swa = hparams.is_swa(il);
1344
1262
 
1345
- const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
1263
+ const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
1346
1264
 
1347
- // store to KV cache
1348
- {
1349
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1350
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1265
+ // optionally store to KV cache
1266
+ if (k_cur) {
1267
+ const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
1268
+
1269
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1270
+ }
1271
+
1272
+ if (v_cur) {
1273
+ const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
1274
+
1275
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1351
1276
  }
1352
1277
 
1353
1278
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1354
1279
 
1355
1280
  ggml_tensor * q = q_cur;
1356
- ggml_tensor * k = kv_state->get_k(ctx0, il);
1357
- ggml_tensor * v = kv_state->get_v(ctx0, il);
1281
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1282
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1358
1283
 
1359
1284
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1360
1285
  cb(cur, "kqv_out", il);
@@ -1379,7 +1304,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1379
1304
 
1380
1305
  const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1381
1306
 
1382
- inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1307
+ inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1383
1308
  ggml_set_input(inp->cross_kq_mask);
1384
1309
 
1385
1310
  inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1429,66 +1354,21 @@ ggml_tensor * llm_graph_context::build_attn(
1429
1354
  return cur;
1430
1355
  }
1431
1356
 
1432
- ggml_tensor * llm_graph_context::build_attn(
1433
- llm_graph_input_mem_hybrid * inp,
1434
- ggml_cgraph * gf,
1435
- ggml_tensor * wo,
1436
- ggml_tensor * wo_b,
1437
- ggml_tensor * q_cur,
1438
- ggml_tensor * k_cur,
1439
- ggml_tensor * v_cur,
1440
- ggml_tensor * kq_b,
1441
- ggml_tensor * v_mla,
1442
- float kq_scale,
1443
- int il) const {
1444
- // these nodes are added to the graph together so that they are not reordered
1445
- // by doing so, the number of splits in the graph is reduced
1446
- ggml_build_forward_expand(gf, q_cur);
1447
- ggml_build_forward_expand(gf, k_cur);
1448
- ggml_build_forward_expand(gf, v_cur);
1449
-
1450
- const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
1451
-
1452
- // store to KV cache
1453
- {
1454
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1455
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1456
- }
1457
-
1458
- const auto & kq_mask = inp->get_kq_mask();
1459
-
1460
- ggml_tensor * q = q_cur;
1461
- ggml_tensor * k = kv_state->get_k(ctx0, il);
1462
- ggml_tensor * v = kv_state->get_v(ctx0, il);
1463
-
1464
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1465
- cb(cur, "kqv_out", il);
1466
-
1467
- if (wo) {
1468
- cur = build_lora_mm(wo, cur);
1469
- if (arch == LLM_ARCH_GLM4) {
1470
- // GLM4 seems to have numerical issues with half-precision accumulators
1471
- ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1472
- }
1473
- }
1474
-
1475
- if (wo_b) {
1476
- cur = ggml_add(ctx0, cur, wo_b);
1477
- }
1478
-
1479
- return cur;
1480
- }
1481
-
1357
+ // TODO: maybe separate the inner implementation into a separate function
1358
+ // like with the non-sliding window equivalent
1359
+ // once sliding-window hybrid caches are a thing.
1482
1360
  llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1483
- const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1361
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1484
1362
 
1485
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
1363
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
1486
1364
 
1487
1365
  {
1488
- const auto n_kv = kv_state->get_base()->get_n_kv();
1366
+ const auto n_kv = mctx_cur->get_base()->get_n_kv();
1489
1367
 
1490
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1491
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1368
+ inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1369
+ inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1370
+
1371
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1492
1372
  ggml_set_input(inp->self_kq_mask);
1493
1373
 
1494
1374
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1497,10 +1377,12 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1497
1377
  {
1498
1378
  GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1499
1379
 
1500
- const auto n_kv = kv_state->get_swa()->get_n_kv();
1380
+ const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1381
+
1382
+ inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1383
+ inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1501
1384
 
1502
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1503
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1385
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1504
1386
  ggml_set_input(inp->self_kq_mask_swa);
1505
1387
 
1506
1388
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1519,7 +1401,7 @@ ggml_tensor * llm_graph_context::build_rs(
1519
1401
  uint32_t kv_head,
1520
1402
  uint32_t kv_size,
1521
1403
  int32_t rs_zero,
1522
- bool avoid_copies) const {
1404
+ const llm_graph_get_rows_fn & get_state_rows) const {
1523
1405
 
1524
1406
  ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
1525
1407
 
@@ -1528,19 +1410,11 @@ ggml_tensor * llm_graph_context::build_rs(
1528
1410
  ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
1529
1411
  ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
1530
1412
 
1531
- ggml_tensor * output_states;
1532
-
1533
- if (!avoid_copies) {
1534
- // copy states
1535
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1536
- // {state_size, kv_size} -> {state_size, n_seqs}
1537
- output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1538
- ggml_build_forward_expand(gf, output_states);
1539
- } else {
1540
- // FIXME: make the gathering operation happen before the copy below
1541
- // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
1542
- output_states = states;
1543
- }
1413
+ // copy states
1414
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1415
+ // {state_size, kv_size} -> {state_size, n_seqs}
1416
+ ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1417
+ ggml_build_forward_expand(gf, output_states);
1544
1418
 
1545
1419
  // copy extra states which won't be changed further (between n_seqs and n_kv)
1546
1420
  ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
@@ -1552,41 +1426,38 @@ ggml_tensor * llm_graph_context::build_rs(
1552
1426
  return output_states;
1553
1427
  }
1554
1428
 
1555
- llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1556
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1429
+ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
1430
+ ggml_context * ctx0,
1431
+ const llama_memory_recurrent_context * mctx_cur) {
1557
1432
 
1558
- auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
1433
+ auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
1559
1434
 
1560
- const auto n_rs = kv_state->get_n_rs();
1435
+ const auto n_rs = mctx_cur->get_n_rs();
1561
1436
 
1562
1437
  inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1563
1438
  ggml_set_input(inp->s_copy);
1564
1439
 
1565
- return (llm_graph_input_rs *) res->add_input(std::move(inp));
1440
+ return inp;
1566
1441
  }
1567
1442
 
1568
- ggml_tensor * llm_graph_context::build_rs(
1569
- llm_graph_input_rs * inp,
1570
- ggml_cgraph * gf,
1571
- ggml_tensor * s,
1572
- int32_t state_size,
1573
- int32_t n_seqs,
1574
- bool avoid_copies) const {
1575
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1443
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1444
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1445
+
1446
+ auto inp = build_rs_inp_impl(ctx0, mctx_cur);
1576
1447
 
1577
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
1448
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
1578
1449
  }
1579
1450
 
1580
1451
  ggml_tensor * llm_graph_context::build_rs(
1581
- llm_graph_input_mem_hybrid * inp,
1452
+ llm_graph_input_rs * inp,
1582
1453
  ggml_cgraph * gf,
1583
1454
  ggml_tensor * s,
1584
1455
  int32_t state_size,
1585
1456
  int32_t n_seqs,
1586
- bool avoid_copies) const {
1587
- const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
1457
+ const llm_graph_get_rows_fn & get_state_rows) const {
1458
+ const auto * kv_state = inp->mctx;
1588
1459
 
1589
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
1460
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
1590
1461
  }
1591
1462
 
1592
1463
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1594,13 +1465,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1594
1465
  ggml_cgraph * gf,
1595
1466
  const llama_ubatch & ubatch,
1596
1467
  int il) const {
1597
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1468
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1598
1469
 
1599
1470
  const auto token_shift_count = hparams.token_shift_count;
1600
1471
 
1601
1472
  const int64_t n_seqs = ubatch.n_seqs;
1602
1473
 
1603
- ggml_tensor * token_shift_all = kv_state->get_r_l(il);
1474
+ ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
1604
1475
 
1605
1476
  ggml_tensor * token_shift = build_rs(
1606
1477
  inp, gf, token_shift_all,
@@ -1615,22 +1486,33 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1615
1486
  ggml_tensor * token_shift,
1616
1487
  const llama_ubatch & ubatch,
1617
1488
  int il) const {
1618
- const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1489
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1619
1490
 
1620
1491
  const auto token_shift_count = hparams.token_shift_count;
1621
1492
  const auto n_embd = hparams.n_embd;
1622
1493
 
1623
1494
  const int64_t n_seqs = ubatch.n_seqs;
1624
1495
 
1625
- const auto kv_head = kv_state->get_head();
1496
+ const auto kv_head = mctx_cur->get_head();
1626
1497
 
1627
1498
  return ggml_cpy(
1628
1499
  ctx0,
1629
1500
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1630
- ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
1501
+ ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
1631
1502
  );
1632
1503
  }
1633
1504
 
1505
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1506
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
1507
+
1508
+ auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
1509
+ auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
1510
+
1511
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
1512
+
1513
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1514
+ }
1515
+
1634
1516
  void llm_graph_context::build_pooling(
1635
1517
  ggml_cgraph * gf,
1636
1518
  ggml_tensor * cls,