@novastera-oss/llamarn 0.2.1 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266) hide show
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/build-info.cpp +2 -2
  14. package/cpp/llama.cpp/README.md +11 -3
  15. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  16. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  17. package/cpp/llama.cpp/common/arg.cpp +153 -113
  18. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  19. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  20. package/cpp/llama.cpp/common/chat.cpp +847 -699
  21. package/cpp/llama.cpp/common/chat.h +73 -6
  22. package/cpp/llama.cpp/common/common.cpp +50 -82
  23. package/cpp/llama.cpp/common/common.h +21 -17
  24. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  25. package/cpp/llama.cpp/common/json-partial.h +37 -0
  26. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  27. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  28. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  29. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  30. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  31. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  32. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  33. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  34. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  35. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  36. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  37. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  74. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  120. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  121. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  122. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  123. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  124. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  125. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  126. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  127. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  128. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  129. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  130. package/cpp/llama.cpp/include/llama.h +62 -125
  131. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  132. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  133. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  150. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  152. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  159. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  160. package/cpp/llama.cpp/models/templates/README.md +2 -0
  161. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  162. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  163. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  164. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  165. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  166. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  167. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  168. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  169. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  170. package/cpp/llama.cpp/src/llama-context.h +30 -0
  171. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  172. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  173. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  174. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  175. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  176. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  177. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  178. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  179. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  180. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  181. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  182. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  183. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  184. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  185. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  186. package/cpp/llama.cpp/src/llama-model.h +6 -1
  187. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  188. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  189. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  190. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  191. package/cpp/llama.cpp/src/llama.cpp +14 -0
  192. package/cpp/rn-completion.cpp +4 -2
  193. package/ios/include/chat.h +73 -6
  194. package/ios/include/common/minja/chat-template.hpp +9 -5
  195. package/ios/include/common/minja/minja.hpp +69 -36
  196. package/ios/include/common.h +21 -17
  197. package/ios/include/llama.h +62 -125
  198. package/ios/libs/llama.xcframework/Info.plist +19 -19
  199. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  200. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  201. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  205. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  206. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  227. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  228. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  229. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  240. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  241. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  242. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  246. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  247. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  253. package/package.json +1 -1
  254. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  255. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  256. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  257. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
@@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> {
33
33
  static constexpr int nwarps_max = 4;
34
34
  static constexpr bool Q_in_reg = true;
35
35
  static constexpr int nstages_target = 2;
36
- static constexpr int nbatch_K2 = 32;
37
- static constexpr int nbatch_V2 = 32;
38
- static constexpr int nbatch_combine = 32;
36
+
37
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
38
+ return 32;
39
+ }
40
+
41
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
42
+ return 32;
43
+ }
44
+
45
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
46
+ return 32;
47
+ }
48
+
49
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
50
+ return 32;
51
+ }
52
+
53
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
54
+ return 32;
55
+ }
56
+
57
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
58
+ return 32;
59
+ }
39
60
  };
40
61
 
41
62
  template <>
@@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> {
44
65
  static constexpr int nwarps_max = 4;
45
66
  static constexpr bool Q_in_reg = true;
46
67
  static constexpr int nstages_target = 2;
47
- static constexpr int nbatch_K2 = 40;
48
- static constexpr int nbatch_V2 = 40;
49
- static constexpr int nbatch_combine = 40;
68
+
69
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
70
+ return 40;
71
+ }
72
+
73
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
74
+ return 40;
75
+ }
76
+
77
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
78
+ return 40;
79
+ }
80
+
81
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
82
+ return 40;
83
+ }
84
+
85
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
86
+ return 40;
87
+ }
88
+
89
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
90
+ return 40;
91
+ }
50
92
  };
51
93
 
52
94
  template <>
@@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> {
55
97
  static constexpr int nwarps_max = 4;
56
98
  static constexpr bool Q_in_reg = true;
57
99
  static constexpr int nstages_target = 2;
58
- static constexpr int nbatch_K2 = 48;
59
- static constexpr int nbatch_V2 = 48;
60
- static constexpr int nbatch_combine = 48;
100
+
101
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
102
+ return 48;
103
+ }
104
+
105
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
106
+ return 48;
107
+ }
108
+
109
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
110
+ return 48;
111
+ }
112
+
113
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
114
+ return 48;
115
+ }
116
+
117
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
118
+ return 48;
119
+ }
120
+
121
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
122
+ return 48;
123
+ }
61
124
  };
62
125
 
63
126
  template <>
@@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
66
129
  static constexpr int nwarps_max = 4;
67
130
  static constexpr bool Q_in_reg = true;
68
131
  static constexpr int nstages_target = 2;
69
- static constexpr int nbatch_K2 = 56;
70
- static constexpr int nbatch_V2 = 56;
71
- static constexpr int nbatch_combine = 56;
132
+
133
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
134
+ return 56;
135
+ }
136
+
137
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
138
+ return 56;
139
+ }
140
+
141
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
142
+ return 56;
143
+ }
144
+
145
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
146
+ return 56;
147
+ }
148
+
149
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
150
+ return 56;
151
+ }
152
+
153
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
154
+ return 56;
155
+ }
72
156
  };
73
157
 
74
158
  template <>
@@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
77
161
  static constexpr int nwarps_max = 4;
78
162
  static constexpr bool Q_in_reg = true;
79
163
  static constexpr int nstages_target = 2;
80
- static constexpr int nbatch_K2 = 64;
81
- static constexpr int nbatch_V2 = 64;
82
- static constexpr int nbatch_combine = 64;
164
+
165
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
166
+ return 64;
167
+ }
168
+
169
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
170
+ return 64;
171
+ }
172
+
173
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
174
+ return 64;
175
+ }
176
+
177
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
178
+ return 64;
179
+ }
180
+
181
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
182
+ return 64;
183
+ }
184
+
185
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
186
+ return 64;
187
+ }
83
188
  };
84
189
 
85
190
  template <>
@@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
88
193
  static constexpr int nwarps_max = 4;
89
194
  static constexpr bool Q_in_reg = true;
90
195
  static constexpr int nstages_target = 2;
91
- static constexpr int nbatch_K2 = 128;
92
- static constexpr int nbatch_V2 = 128;
93
- static constexpr int nbatch_combine = 128;
196
+
197
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
198
+ return 128;
199
+ }
200
+
201
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
202
+ return 128;
203
+ }
204
+
205
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
206
+ return 128;
207
+ }
208
+
209
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
210
+ return 128;
211
+ }
212
+
213
+ static int get_nbatch_combine_host(const int cc, const int ncols) {
214
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
215
+ return ncols <= 16 ? 128 : 64;
216
+ }
217
+ return 64;
218
+ }
219
+
220
+ static constexpr __device__ int get_nbatch_combine_device(int ncols) {
221
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
222
+ return ncols <= 16 ? 128 : 64;
223
+ #else
224
+ GGML_UNUSED(ncols);
225
+ return 128;
226
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
227
+ }
94
228
  };
95
229
 
96
230
  template <>
@@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
99
233
  static constexpr int nwarps_max = 8;
100
234
  static constexpr bool Q_in_reg = false;
101
235
  static constexpr int nstages_target = 1;
102
- static constexpr int nbatch_K2 = 160;
103
- static constexpr int nbatch_V2 = 128;
104
- static constexpr int nbatch_combine = 128;
236
+
237
+ static int get_nbatch_K2_host(const int cc, const int ncols) {
238
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
239
+ return ncols <= 16 ? 96 : 160;
240
+ }
241
+ return ncols <= 16 ? 288 : 160;
242
+ }
243
+
244
+ static constexpr __device__ int get_nbatch_K2_device(int ncols) {
245
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
246
+ return ncols <= 16 ? 96 : 160;
247
+ #else
248
+ return ncols <= 16 ? 288 : 160;
249
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
250
+ }
251
+
252
+ static int get_nbatch_V2_host(const int cc, const int ncols) {
253
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
254
+ return ncols <= 16 ? 64 : 128;
255
+ }
256
+ return ncols <= 16 ? 256 : 128;
257
+ }
258
+
259
+ static constexpr __device__ int get_nbatch_V2_device(int ncols) {
260
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
261
+ return ncols <= 16 ? 64 : 128;
262
+ #else
263
+ return ncols <= 16 ? 256 : 128;
264
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
265
+ }
266
+
267
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
268
+ return 128;
269
+ }
270
+
271
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
272
+ return 128;
273
+ }
105
274
  };
106
275
 
107
276
  // ------------------------------------------------------------------------------------------------------------------
@@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
120
289
 
121
290
  const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
122
291
 
123
- auto load = [&] __device__ (const int n) {
292
+ auto load = [&] __device__ (auto n) {
124
293
  const int stride_k = WARP_SIZE >> n;
125
294
  const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
126
295
  const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
@@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
223
392
  }
224
393
  }
225
394
 
226
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
395
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
227
396
  static __device__ __forceinline__ void flash_attn_ext_f16_iter(
228
397
  const float2 * const __restrict__ Q_f2,
229
398
  const half2 * const __restrict__ K_h2,
@@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
261
430
  constexpr int cols_per_warp = ntiles * tile_B::I;
262
431
  constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
263
432
  constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
433
+ constexpr int ncols = ncols1 * ncols2;
434
+ constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
435
+ constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
436
+
437
+ constexpr int stride_tile_Q = DKQ/2 + 4;
438
+ constexpr int stride_tile_K = nbatch_K2 + 4;
264
439
 
265
- constexpr int stride_tile_Q = DKQ/2 + 4;
266
- constexpr int stride_tile_K = c::nbatch_K2 + 4;
267
- constexpr int stride_tile_V = c::nbatch_V2 + 4;
440
+ static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
441
+ constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
268
442
 
269
443
  const int k_VKQ_0 = kb0 * c::nbatch_fa;
270
444
  tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
@@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
275
449
  tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
276
450
 
277
451
  if constexpr (nstages > 1) {
278
- static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
452
+ static_assert(!mla, "multi-stage loading not implemented for MLA");
453
+ static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
279
454
  constexpr bool use_cp_async = true;
280
455
  cp_async_wait_all();
281
456
  __syncthreads();
282
457
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
283
- (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
458
+ (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
284
459
  } else {
285
460
  constexpr bool use_cp_async = nstages == 1;
286
461
  if (ncols2 > 1 || mask_h2) {
@@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
289
464
  }
290
465
 
291
466
  #pragma unroll
292
- for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
293
- const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
467
+ for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
468
+ const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
294
469
  const int k0_diff = k0_stop - k0_start;
295
470
 
296
471
  if (nstages <= 1) {
@@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
537
712
  (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
538
713
  }
539
714
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
540
- (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
715
+ (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
541
716
  }
542
717
  }
543
718
 
719
+
720
+ // For MLA K and V have the same data.
721
+ // Therefore, iterate over V in reverse and re-use the data if possible.
722
+ static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
723
+ constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
544
724
  #pragma unroll
545
- for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
546
- const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
547
- const int i0_diff = i0_stop - i0_start;
725
+ for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
726
+ const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
727
+ const int i0_diff = i0_stop - i0_start;
548
728
 
549
- if (nstages <= 1) {
729
+ if (nstages <= 1 && i0_start < reusable_cutoff) {
550
730
  constexpr bool use_cp_async = nstages == 1;
551
731
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
552
732
  (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
@@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
555
735
  }
556
736
  __syncthreads();
557
737
  }
738
+ const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
558
739
 
559
740
  // Calculate VKQ tile:
560
741
  #pragma unroll
@@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
565
746
  const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
566
747
 
567
748
  tile_A A;
568
- load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
749
+ load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
569
750
  if (ntiles == 1) {
570
751
  mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
571
752
  } else {
@@ -591,12 +772,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
591
772
  GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
592
773
  GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
593
774
  GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
594
- GGML_UNUSED(kb0);
775
+ GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
595
776
  NO_DEVICE_CODE;
596
777
  #endif // NEW_MMA_AVAILABLE
597
778
  }
598
779
 
599
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
780
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
600
781
  static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
601
782
  const float2 * const __restrict__ Q_f2,
602
783
  const half2 * const __restrict__ K_h2,
@@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
632
813
  constexpr int cols_per_warp = ntiles * tile_B::I;
633
814
  constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
634
815
  constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
816
+ constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
817
+ constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
635
818
 
636
819
  static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
637
820
 
638
- constexpr int stride_tile_Q = DKQ/2 + 4;
639
- constexpr int stride_tile_K = c::nbatch_K2 + 4;
640
- constexpr int stride_tile_V = c::nbatch_V2 + 4;
821
+ constexpr int stride_tile_Q = DKQ/2 + 4;
822
+ constexpr int stride_tile_K = nbatch_K2 + 4;
641
823
 
824
+ static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
825
+ constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
642
826
  constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
643
827
 
644
828
  extern __shared__ half2 tile_Q[];
@@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
726
910
 
727
911
  // Preload mask and K data for first iteration when using cp_async with multiple stages:
728
912
  if constexpr (nstages > 1) {
729
- static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
913
+ static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
730
914
  constexpr bool use_cp_async = true;
731
915
  if (ncols2 > 1 || mask_h2) {
732
916
  flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
733
917
  (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
734
918
  }
735
919
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
736
- (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
920
+ (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
737
921
  }
738
922
 
739
923
  // Iterate over ne11 == previous tokens:
740
924
  for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
741
925
  constexpr bool last_iter = false;
742
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
926
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
743
927
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
744
928
  ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
745
929
  }
746
930
  { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
747
931
  constexpr bool last_iter = true;
748
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
932
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
749
933
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
750
934
  ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
751
935
  }
@@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
774
958
  // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
775
959
  // So also write VKQ accumulators to shared memory in column-major format if np == 1.
776
960
 
777
- constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
961
+ constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
778
962
  constexpr int tile_stride = nbatch_combine + 4;
779
963
  static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
780
964
 
@@ -895,6 +1079,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
895
1079
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
896
1080
  dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
897
1081
  }
1082
+ } else if (np > 1) {
1083
+ // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
1084
+ // Therefore, all other warps also need to execute a __syncthreads().
1085
+ // Otherwise the points at which warps synchronize with each other would become misaligned.
1086
+ __syncthreads();
898
1087
  }
899
1088
 
900
1089
  #pragma unroll
@@ -1007,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1007
1196
  #endif // NEW_MMA_AVAILABLE
1008
1197
  }
1009
1198
 
1010
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
1199
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
1011
1200
  __launch_bounds__(nwarps*WARP_SIZE, 1)
1012
1201
  static __global__ void flash_attn_ext_f16(
1013
1202
  const char * __restrict__ Q,
@@ -1052,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
1052
1241
  NO_DEVICE_CODE;
1053
1242
  return;
1054
1243
  }
1244
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1245
+ if (ncols1*ncols2 > 32) {
1246
+ NO_DEVICE_CODE;
1247
+ return;
1248
+ }
1249
+ #endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1250
+
1251
+ static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
1055
1252
 
1056
1253
  typedef fattn_mma_f16_config<DKQ, DV> c;
1057
1254
 
@@ -1062,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
1062
1259
  const int stride_Q1 = nb01 / sizeof(float2);
1063
1260
  const int stride_Q2 = nb02 / sizeof(float2);
1064
1261
  const int stride_K = nb11 / sizeof(half2);
1065
- const int stride_V = nb21 / sizeof(half2);
1066
1262
  const int stride_mask = nb31 / sizeof(half2);
1067
1263
 
1264
+ const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
1265
+
1068
1266
  const int iter_k = ne11 / FATTN_KQ_STRIDE;
1069
1267
  const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
1070
1268
 
@@ -1087,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
1087
1285
 
1088
1286
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1089
1287
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1090
- const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1091
1288
  const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1092
1289
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1093
1290
 
1291
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1292
+
1094
1293
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1095
1294
 
1096
1295
  const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1099,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
1099
1298
  constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
1100
1299
  if (kb0_start == 0) {
1101
1300
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
1102
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
1301
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1103
1302
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1104
1303
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1105
1304
  } else {
1106
1305
  constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
1107
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
1306
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1108
1307
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1109
1308
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1110
1309
  }
@@ -1125,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
1125
1324
 
1126
1325
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1127
1326
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1128
- const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
1129
1327
  const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1130
1328
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1131
1329
 
1330
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1331
+
1132
1332
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1133
1333
 
1134
1334
  const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1136,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
1136
1336
 
1137
1337
  constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
1138
1338
  constexpr bool needs_fixup = false;
1139
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
1339
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1140
1340
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1141
1341
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1142
1342
  #else
@@ -1162,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1162
1362
 
1163
1363
  typedef fattn_mma_f16_config<DKQ, DV> c;
1164
1364
 
1165
- constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
1166
- constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
1167
- constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
1168
-
1169
1365
  const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
1170
1366
 
1171
1367
  constexpr int ncols = ncols1 * ncols2;
@@ -1175,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1175
1371
  constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
1176
1372
  constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
1177
1373
 
1374
+ constexpr bool mla = DKQ == 576;
1375
+
1376
+ const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
1377
+ const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
1378
+ const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
1379
+
1178
1380
  static_assert(DKQ % tile_B::J == 0, "bad DKQ");
1179
1381
  static_assert(DV % tile_A::J == 0, "bad DV");
1180
1382
  static_assert(ncols % cols_per_warp == 0, "bad ncols");
1181
1383
 
1182
- const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
1183
- const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
1184
- const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
1185
- const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
1186
- const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
1384
+ const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
1385
+ const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
1386
+ const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
1387
+ const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
1388
+ const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
1187
1389
 
1188
1390
  const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
1189
1391
 
@@ -1197,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1197
1399
  fattn_kernel_t fattn_kernel;
1198
1400
  if (logit_softcap == 0.0f) {
1199
1401
  constexpr bool use_logit_softcap = false;
1200
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
1402
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
1201
1403
 
1202
1404
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1203
1405
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1208,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1208
1410
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1209
1411
  } else {
1210
1412
  constexpr bool use_logit_softcap = true;
1211
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
1413
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
1212
1414
 
1213
1415
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
1214
1416
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};