@novastera-oss/llamarn 0.4.0 → 0.4.3-beta4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (979)
  1. package/RNLlamaCpp.podspec +4 -1
  2. package/android/CMakeLists.txt +13 -3
  3. package/android/src/main/cpp/include/llama.h +44 -21
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/LlamaCppModel.cpp +2 -10
  21. package/cpp/SystemUtils.cpp +3 -7
  22. package/cpp/build-info.cpp +2 -2
  23. package/cpp/llama.cpp/CMakeLists.txt +12 -0
  24. package/cpp/llama.cpp/CODEOWNERS +116 -10
  25. package/cpp/llama.cpp/CONTRIBUTING.md +30 -3
  26. package/cpp/llama.cpp/README.md +13 -5
  27. package/cpp/llama.cpp/build-xcframework.sh +5 -0
  28. package/cpp/llama.cpp/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  29. package/cpp/llama.cpp/common/CMakeLists.txt +12 -2
  30. package/cpp/llama.cpp/common/arg.cpp +303 -795
  31. package/cpp/llama.cpp/common/arg.h +2 -3
  32. package/cpp/llama.cpp/common/chat-parser-xml-toolcall.cpp +861 -0
  33. package/cpp/llama.cpp/common/chat-parser-xml-toolcall.h +45 -0
  34. package/cpp/llama.cpp/common/chat-parser.cpp +156 -15
  35. package/cpp/llama.cpp/common/chat-parser.h +13 -0
  36. package/cpp/llama.cpp/common/chat.cpp +1147 -88
  37. package/cpp/llama.cpp/common/chat.h +16 -3
  38. package/cpp/llama.cpp/common/common.cpp +70 -15
  39. package/cpp/llama.cpp/common/common.h +57 -19
  40. package/cpp/llama.cpp/common/download.cpp +1072 -0
  41. package/cpp/llama.cpp/common/download.h +55 -0
  42. package/cpp/llama.cpp/common/http.h +73 -0
  43. package/cpp/llama.cpp/common/json-partial.cpp +70 -2
  44. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +61 -22
  45. package/cpp/llama.cpp/common/json-schema-to-grammar.h +2 -0
  46. package/cpp/llama.cpp/common/log.cpp +59 -2
  47. package/cpp/llama.cpp/common/log.h +12 -4
  48. package/cpp/llama.cpp/common/sampling.cpp +84 -8
  49. package/cpp/llama.cpp/common/sampling.h +3 -1
  50. package/cpp/llama.cpp/common/speculative.cpp +1 -1
  51. package/cpp/llama.cpp/convert_hf_to_gguf.py +1608 -233
  52. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +6 -1
  53. package/cpp/llama.cpp/convert_lora_to_gguf.py +37 -5
  54. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -28
  55. package/cpp/llama.cpp/ggml/include/ggml-backend.h +19 -1
  56. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +1 -1
  57. package/cpp/llama.cpp/ggml/include/ggml-hexagon.h +19 -0
  58. package/cpp/llama.cpp/ggml/include/ggml-metal.h +1 -6
  59. package/cpp/llama.cpp/ggml/include/ggml-rpc.h +7 -9
  60. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +2 -1
  61. package/cpp/llama.cpp/ggml/include/ggml.h +199 -6
  62. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +38 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +299 -130
  64. package/cpp/llama.cpp/ggml/src/ggml-backend-impl.h +4 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +21 -5
  66. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +99 -2
  67. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +1 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  70. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +138 -47
  71. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +1584 -1773
  72. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +201 -317
  73. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +146 -187
  74. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +771 -713
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +135 -77
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +16 -17
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +318 -145
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +155 -60
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +8 -8
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +14 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -9
  86. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +108 -64
  87. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +14 -4
  88. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +530 -87
  89. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +37 -45
  90. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +349 -127
  91. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +947 -1218
  92. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +143 -29
  94. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +82 -76
  95. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +7 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +233 -28
  102. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +326 -66
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +12 -3
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/argsort.cu +102 -6
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +110 -76
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +167 -38
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +6 -11
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +12 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +245 -151
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cuh +1 -5
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +341 -289
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh +1233 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +6 -6
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +123 -220
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +41 -39
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +715 -45
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +150 -0
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cuh +1 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +321 -24
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +93 -351
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +828 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmid.cu +164 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmid.cuh +5 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +3 -166
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1 -1
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvf.cu +371 -78
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +279 -147
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +97 -85
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad.cu +46 -23
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +63 -54
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +12 -10
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +192 -77
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +10 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +137 -75
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/set.cu +39 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/set.cuh +7 -0
  144. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  152. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  167. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  173. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  174. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  176. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  178. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  179. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  180. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  183. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  184. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  186. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  187. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  188. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  189. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  190. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  195. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  196. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  197. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  198. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  199. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  201. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  202. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  203. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  204. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu +336 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh +16 -0
  208. package/cpp/llama.cpp/ggml/src/ggml-cuda/tsembd.cu +3 -3
  209. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +105 -11
  210. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +36 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +87 -6
  212. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +28 -12
  213. package/cpp/llama.cpp/ggml/src/ggml-hexagon/CMakeLists.txt +68 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3807 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +40 -0
  216. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c +442 -0
  217. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  218. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  219. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +40 -0
  220. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.c +69 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.h +119 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +156 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +64 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  225. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.c +93 -0
  226. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.c +60 -0
  227. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  228. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.c +960 -0
  229. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +1032 -0
  230. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/main.c +829 -0
  231. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2223 -0
  232. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  233. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +418 -0
  234. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  235. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +255 -0
  236. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  237. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  238. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp-utils.c +448 -0
  239. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp-utils.h +220 -0
  240. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  241. package/cpp/llama.cpp/ggml/src/ggml-impl.h +110 -12
  242. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +6 -5
  243. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  244. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  245. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  246. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m +599 -0
  247. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1662 -0
  248. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h +251 -0
  249. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m +1527 -0
  250. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +244 -39
  251. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +3844 -0
  252. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h +90 -0
  253. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp +723 -0
  254. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +3453 -1907
  255. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  256. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +10 -0
  257. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1331 -109
  258. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/cvt.cl +126 -0
  259. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +31 -4
  260. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +35 -7
  261. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +31 -4
  262. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  263. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  264. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  265. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  266. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  267. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  268. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  269. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  270. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  271. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  272. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  273. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  274. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  275. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  276. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  277. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  278. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +123 -10
  279. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  280. package/cpp/llama.cpp/ggml/src/ggml-quants.c +1 -0
  281. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +341 -161
  282. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  283. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  284. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +74 -15
  285. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +50 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +10 -4
  287. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +166 -99
  288. package/cpp/llama.cpp/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  289. package/cpp/llama.cpp/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  290. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +72 -94
  291. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  292. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +21 -31
  293. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +252 -316
  294. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +6 -2
  295. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +9 -6
  296. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +359 -142
  297. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  298. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  299. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +80 -60
  300. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  301. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +230 -55
  302. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.hpp +2 -0
  303. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad.cpp +97 -0
  304. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad.hpp +24 -0
  305. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad_reflect_1d.cpp +72 -0
  306. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad_reflect_1d.hpp +8 -0
  307. package/cpp/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  308. package/cpp/llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  309. package/cpp/llama.cpp/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  310. package/cpp/llama.cpp/ggml/src/ggml-sycl/roll.cpp +122 -0
  311. package/cpp/llama.cpp/ggml/src/ggml-sycl/roll.hpp +20 -0
  312. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +50 -41
  313. package/cpp/llama.cpp/ggml/src/ggml-sycl/set.cpp +73 -0
  314. package/cpp/llama.cpp/ggml/src/ggml-sycl/set.hpp +5 -0
  315. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +45 -36
  316. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +330 -165
  317. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +4 -0
  318. package/cpp/llama.cpp/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  319. package/cpp/llama.cpp/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  320. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  321. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +16 -12
  322. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  323. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4184 -2159
  324. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  325. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  326. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  327. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  328. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  329. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  330. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  331. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  332. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  333. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  334. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  335. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  336. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  337. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  338. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +53 -30
  339. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  340. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  341. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  342. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +13 -6
  343. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  344. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  345. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  346. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  347. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +138 -2
  348. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  349. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  350. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  351. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  352. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  353. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  354. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  355. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  356. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  357. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  358. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  359. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  360. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  361. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  362. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  363. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  364. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  365. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  366. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  367. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  368. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  369. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  370. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  371. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  372. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -2
  373. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  374. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +52 -14
  375. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +50 -12
  376. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +61 -12
  377. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +54 -12
  378. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +5 -1
  379. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  380. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  381. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  382. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  383. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  384. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  385. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  386. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +10 -2
  387. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  388. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  389. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  390. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  391. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  392. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  393. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +15 -7
  394. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  395. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  396. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  397. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  398. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  399. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +1 -1
  400. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +229 -0
  401. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +33 -0
  402. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +1 -1
  403. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +1 -1
  404. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +1 -1
  405. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +1 -1
  406. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +1 -1
  407. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +1 -1
  408. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +1 -1
  409. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  410. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  411. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +3 -5
  412. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +1 -1
  413. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +3 -5
  414. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +3 -5
  415. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +1 -1
  416. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  417. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +106 -634
  418. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +118 -9
  419. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +556 -0
  420. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +70 -0
  421. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +77 -214
  422. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +589 -0
  423. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  424. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  425. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  426. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  427. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  428. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  429. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +25 -4
  430. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  431. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +55 -5
  432. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  433. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  434. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  435. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  436. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +45 -3
  437. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  438. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  439. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  440. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +227 -0
  441. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  442. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +5 -52
  443. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +5 -35
  444. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +5 -35
  445. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +27 -0
  446. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +5 -41
  447. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  448. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  449. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  450. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  451. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  452. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  453. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  454. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  455. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  456. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  457. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  458. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  459. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +140 -0
  460. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  461. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  462. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +1 -1
  463. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  464. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  465. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  466. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  467. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +171 -0
  468. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  469. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +79 -29
  470. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -12
  471. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +471 -196
  472. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +8 -0
  473. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1690 -383
  474. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  475. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  476. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  477. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  478. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +57 -10
  479. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  480. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  481. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +25 -912
  482. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  483. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  484. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  485. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  486. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  487. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  488. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  489. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/{set_rows.wgsl → set_rows.tmpl.wgsl} +38 -8
  490. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  491. package/cpp/llama.cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
  492. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +96 -314
  493. package/cpp/llama.cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  494. package/cpp/llama.cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  495. package/cpp/llama.cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
  496. package/cpp/llama.cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
  497. package/cpp/llama.cpp/ggml/src/ggml.c +440 -17
  498. package/cpp/llama.cpp/ggml/src/gguf.cpp +104 -29
  499. package/cpp/llama.cpp/gguf-py/gguf/constants.py +363 -13
  500. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +64 -0
  501. package/cpp/llama.cpp/gguf-py/gguf/lazy.py +8 -3
  502. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +6 -0
  503. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +156 -18
  504. package/cpp/llama.cpp/gguf-py/gguf/utility.py +80 -0
  505. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +4 -4
  506. package/cpp/llama.cpp/include/llama.h +44 -21
  507. package/cpp/llama.cpp/media/llama1-icon-transparent.png +0 -0
  508. package/cpp/llama.cpp/media/llama1-icon-transparent.svg +77 -0
  509. package/cpp/llama.cpp/media/llama1-icon.png +0 -0
  510. package/cpp/llama.cpp/media/llama1-icon.svg +87 -0
  511. package/cpp/llama.cpp/requirements/requirements-all.txt +2 -0
  512. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -3
  513. package/cpp/llama.cpp/requirements/requirements-convert_legacy_llama.txt +3 -1
  514. package/cpp/llama.cpp/requirements/requirements-tool_bench.txt +1 -1
  515. package/cpp/llama.cpp/src/CMakeLists.txt +101 -0
  516. package/cpp/llama.cpp/src/llama-adapter.cpp +33 -0
  517. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  518. package/cpp/llama.cpp/src/llama-arch.cpp +344 -14
  519. package/cpp/llama.cpp/src/llama-arch.h +50 -0
  520. package/cpp/llama.cpp/src/llama-batch.cpp +63 -31
  521. package/cpp/llama.cpp/src/llama-batch.h +13 -2
  522. package/cpp/llama.cpp/src/llama-chat.cpp +85 -3
  523. package/cpp/llama.cpp/src/llama-chat.h +4 -0
  524. package/cpp/llama.cpp/src/llama-context.cpp +300 -45
  525. package/cpp/llama.cpp/src/llama-context.h +16 -6
  526. package/cpp/llama.cpp/src/llama-cparams.h +2 -1
  527. package/cpp/llama.cpp/src/llama-grammar.cpp +17 -9
  528. package/cpp/llama.cpp/src/llama-graph.cpp +226 -64
  529. package/cpp/llama.cpp/src/llama-graph.h +27 -5
  530. package/cpp/llama.cpp/src/llama-hparams.cpp +53 -2
  531. package/cpp/llama.cpp/src/llama-hparams.h +48 -8
  532. package/cpp/llama.cpp/src/llama-impl.cpp +3 -3
  533. package/cpp/llama.cpp/src/llama-impl.h +2 -0
  534. package/cpp/llama.cpp/src/llama-kv-cache-iswa.cpp +13 -3
  535. package/cpp/llama.cpp/src/llama-kv-cache-iswa.h +2 -0
  536. package/cpp/llama.cpp/src/llama-kv-cache.cpp +120 -62
  537. package/cpp/llama.cpp/src/llama-kv-cache.h +13 -4
  538. package/cpp/llama.cpp/src/llama-kv-cells.h +44 -2
  539. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +19 -9
  540. package/cpp/llama.cpp/src/llama-memory-hybrid.h +2 -0
  541. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +38 -17
  542. package/cpp/llama.cpp/src/llama-memory-recurrent.h +5 -2
  543. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  544. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  545. package/cpp/llama.cpp/src/llama-model.cpp +1070 -12614
  546. package/cpp/llama.cpp/src/llama-model.h +40 -4
  547. package/cpp/llama.cpp/src/llama-quant.cpp +14 -6
  548. package/cpp/llama.cpp/src/llama-sampling.cpp +243 -136
  549. package/cpp/llama.cpp/src/llama-vocab.cpp +43 -3
  550. package/cpp/llama.cpp/src/llama-vocab.h +43 -39
  551. package/cpp/llama.cpp/src/llama.cpp +69 -10
  552. package/cpp/llama.cpp/src/models/afmoe.cpp +187 -0
  553. package/cpp/llama.cpp/src/models/apertus.cpp +125 -0
  554. package/cpp/llama.cpp/src/models/arcee.cpp +135 -0
  555. package/cpp/llama.cpp/src/models/arctic.cpp +138 -0
  556. package/cpp/llama.cpp/src/models/arwkv7.cpp +86 -0
  557. package/cpp/llama.cpp/src/models/baichuan.cpp +122 -0
  558. package/cpp/llama.cpp/src/models/bailingmoe.cpp +144 -0
  559. package/cpp/llama.cpp/src/models/bailingmoe2.cpp +135 -0
  560. package/cpp/llama.cpp/src/models/bert.cpp +176 -0
  561. package/cpp/llama.cpp/src/models/bitnet.cpp +160 -0
  562. package/cpp/llama.cpp/src/models/bloom.cpp +101 -0
  563. package/cpp/llama.cpp/src/models/chameleon.cpp +178 -0
  564. package/cpp/llama.cpp/src/models/chatglm.cpp +132 -0
  565. package/cpp/llama.cpp/src/models/codeshell.cpp +111 -0
  566. package/cpp/llama.cpp/src/models/cogvlm.cpp +100 -0
  567. package/cpp/llama.cpp/src/models/cohere2-iswa.cpp +131 -0
  568. package/cpp/llama.cpp/src/models/command-r.cpp +122 -0
  569. package/cpp/llama.cpp/src/models/dbrx.cpp +123 -0
  570. package/cpp/llama.cpp/src/models/deci.cpp +135 -0
  571. package/cpp/llama.cpp/src/models/deepseek.cpp +144 -0
  572. package/cpp/llama.cpp/src/models/deepseek2.cpp +237 -0
  573. package/cpp/llama.cpp/src/models/dots1.cpp +134 -0
  574. package/cpp/llama.cpp/src/models/dream.cpp +105 -0
  575. package/cpp/llama.cpp/src/models/ernie4-5-moe.cpp +150 -0
  576. package/cpp/llama.cpp/src/models/ernie4-5.cpp +110 -0
  577. package/cpp/llama.cpp/src/models/exaone.cpp +114 -0
  578. package/cpp/llama.cpp/src/models/exaone4.cpp +123 -0
  579. package/cpp/llama.cpp/src/models/falcon-h1.cpp +113 -0
  580. package/cpp/llama.cpp/src/models/falcon.cpp +120 -0
  581. package/cpp/llama.cpp/src/models/gemma-embedding.cpp +120 -0
  582. package/cpp/llama.cpp/src/models/gemma.cpp +112 -0
  583. package/cpp/llama.cpp/src/models/gemma2-iswa.cpp +125 -0
  584. package/cpp/llama.cpp/src/models/gemma3-iswa.cpp +131 -0
  585. package/cpp/llama.cpp/src/models/gemma3n-iswa.cpp +377 -0
  586. package/cpp/llama.cpp/src/models/glm4-moe.cpp +153 -0
  587. package/cpp/llama.cpp/src/models/glm4.cpp +127 -0
  588. package/cpp/llama.cpp/src/models/gpt2.cpp +105 -0
  589. package/cpp/llama.cpp/src/models/gptneox.cpp +144 -0
  590. package/cpp/llama.cpp/src/models/granite-hybrid.cpp +196 -0
  591. package/cpp/llama.cpp/src/models/granite.cpp +211 -0
  592. package/cpp/llama.cpp/src/models/graph-context-mamba.cpp +283 -0
  593. package/cpp/llama.cpp/src/models/grok.cpp +159 -0
  594. package/cpp/llama.cpp/src/models/grovemoe.cpp +141 -0
  595. package/cpp/llama.cpp/src/models/hunyuan-dense.cpp +132 -0
  596. package/cpp/llama.cpp/src/models/hunyuan-moe.cpp +154 -0
  597. package/cpp/llama.cpp/src/models/internlm2.cpp +120 -0
  598. package/cpp/llama.cpp/src/models/jais.cpp +86 -0
  599. package/cpp/llama.cpp/src/models/jamba.cpp +106 -0
  600. package/cpp/llama.cpp/src/models/lfm2.cpp +173 -0
  601. package/cpp/llama.cpp/src/models/llada-moe.cpp +122 -0
  602. package/cpp/llama.cpp/src/models/llada.cpp +99 -0
  603. package/cpp/llama.cpp/src/models/llama-iswa.cpp +174 -0
  604. package/cpp/llama.cpp/src/models/llama.cpp +155 -0
  605. package/cpp/llama.cpp/src/models/mamba.cpp +55 -0
  606. package/cpp/llama.cpp/src/models/minicpm3.cpp +199 -0
  607. package/cpp/llama.cpp/src/models/minimax-m2.cpp +124 -0
  608. package/cpp/llama.cpp/src/models/models.h +485 -0
  609. package/cpp/llama.cpp/src/models/mpt.cpp +126 -0
  610. package/cpp/llama.cpp/src/models/nemotron-h.cpp +121 -0
  611. package/cpp/llama.cpp/src/models/nemotron.cpp +122 -0
  612. package/cpp/llama.cpp/src/models/neo-bert.cpp +104 -0
  613. package/cpp/llama.cpp/src/models/olmo.cpp +121 -0
  614. package/cpp/llama.cpp/src/models/olmo2.cpp +150 -0
  615. package/cpp/llama.cpp/src/models/olmoe.cpp +124 -0
  616. package/cpp/llama.cpp/src/models/openai-moe-iswa.cpp +124 -0
  617. package/cpp/llama.cpp/src/models/openelm.cpp +124 -0
  618. package/cpp/llama.cpp/src/models/orion.cpp +123 -0
  619. package/cpp/llama.cpp/src/models/pangu-embedded.cpp +121 -0
  620. package/cpp/llama.cpp/src/models/phi2.cpp +121 -0
  621. package/cpp/llama.cpp/src/models/phi3.cpp +152 -0
  622. package/cpp/llama.cpp/src/models/plamo.cpp +110 -0
  623. package/cpp/llama.cpp/src/models/plamo2.cpp +316 -0
  624. package/cpp/llama.cpp/src/models/plm.cpp +168 -0
  625. package/cpp/llama.cpp/src/models/qwen.cpp +108 -0
  626. package/cpp/llama.cpp/src/models/qwen2.cpp +117 -0
  627. package/cpp/llama.cpp/src/models/qwen2moe.cpp +151 -0
  628. package/cpp/llama.cpp/src/models/qwen2vl.cpp +117 -0
  629. package/cpp/llama.cpp/src/models/qwen3.cpp +117 -0
  630. package/cpp/llama.cpp/src/models/qwen3moe.cpp +124 -0
  631. package/cpp/llama.cpp/src/models/qwen3vl-moe.cpp +149 -0
  632. package/cpp/llama.cpp/src/models/qwen3vl.cpp +141 -0
  633. package/cpp/llama.cpp/src/models/refact.cpp +94 -0
  634. package/cpp/llama.cpp/src/models/rwkv6-base.cpp +162 -0
  635. package/cpp/llama.cpp/src/models/rwkv6.cpp +94 -0
  636. package/cpp/llama.cpp/src/models/rwkv6qwen2.cpp +86 -0
  637. package/cpp/llama.cpp/src/models/rwkv7-base.cpp +135 -0
  638. package/cpp/llama.cpp/src/models/rwkv7.cpp +90 -0
  639. package/cpp/llama.cpp/src/models/seed-oss.cpp +124 -0
  640. package/cpp/llama.cpp/src/models/smallthinker.cpp +120 -0
  641. package/cpp/llama.cpp/src/models/smollm3.cpp +128 -0
  642. package/cpp/llama.cpp/src/models/stablelm.cpp +146 -0
  643. package/cpp/llama.cpp/src/models/starcoder.cpp +100 -0
  644. package/cpp/llama.cpp/src/models/starcoder2.cpp +121 -0
  645. package/cpp/llama.cpp/src/models/t5-dec.cpp +166 -0
  646. package/cpp/llama.cpp/src/models/t5-enc.cpp +96 -0
  647. package/cpp/llama.cpp/src/models/wavtokenizer-dec.cpp +149 -0
  648. package/cpp/llama.cpp/src/models/xverse.cpp +108 -0
  649. package/cpp/llama.cpp/src/unicode.cpp +77 -0
  650. package/cpp/llama.cpp/src/unicode.h +43 -0
  651. package/cpp/llama.cpp/vendor/cpp-httplib/CMakeLists.txt +94 -0
  652. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.cpp +9339 -0
  653. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +433 -8222
  654. package/cpp/llama.cpp/vendor/cpp-httplib/patch-boringssl.cmake +6 -0
  655. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +4179 -1900
  656. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +9 -2
  657. package/cpp/llama.cpp/vendor/minja/minja.hpp +101 -22
  658. package/cpp/rn-completion.cpp +3 -27
  659. package/ios/include/chat.h +16 -3
  660. package/ios/include/common/minja/chat-template.hpp +9 -2
  661. package/ios/include/common/minja/minja.hpp +101 -22
  662. package/ios/include/common.h +57 -19
  663. package/ios/include/json-schema-to-grammar.h +2 -0
  664. package/ios/include/llama.h +44 -21
  665. package/ios/include/log.h +12 -4
  666. package/ios/include/sampling.h +3 -1
  667. package/ios/libs/llama.xcframework/Info.plist +20 -20
  668. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  669. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +6399 -5557
  670. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +19 -1
  671. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +1 -1
  672. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-metal.h +1 -6
  673. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +199 -6
  674. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +44 -21
  675. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  676. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  677. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +6362 -5520
  678. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4813 -4241
  679. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +19 -1
  680. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +1 -1
  681. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +1 -6
  682. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +199 -6
  683. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +44 -21
  684. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  685. package/package.json +10 -4
  686. package/cpp/llama.cpp/ggml/src/ggml-cann/Doxyfile +0 -2579
  687. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -371
  688. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  689. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -379
  690. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  691. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -495
  692. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -486
  693. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  694. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  695. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  696. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  697. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  698. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  699. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  700. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  701. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  702. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  703. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  704. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  705. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  706. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  707. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  708. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  709. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  710. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  711. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  712. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  713. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  714. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  715. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  716. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  717. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  718. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  719. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  720. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  721. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  722. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  723. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  724. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  725. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  726. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  727. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  728. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  729. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  730. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  731. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  732. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  733. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  734. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  735. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  736. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  737. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  738. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  739. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  740. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  741. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  742. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  743. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  744. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  745. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  746. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  747. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  748. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  749. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  750. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  751. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  752. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  753. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  754. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  755. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  756. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  757. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  758. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  759. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  760. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  761. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  762. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  763. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  764. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  765. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  766. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  767. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  768. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  769. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  770. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  771. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  772. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  773. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  774. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  775. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  776. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  777. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  778. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  779. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +0 -6886
  780. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -154
  781. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  782. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  783. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  784. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +0 -97
  785. package/cpp/llama.cpp/models/ggml-vocab-aquila.gguf +0 -0
  786. package/cpp/llama.cpp/models/ggml-vocab-baichuan.gguf +0 -0
  787. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf +0 -0
  788. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +0 -112
  789. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +0 -46
  790. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf +0 -0
  791. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +0 -112
  792. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +0 -46
  793. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf +0 -0
  794. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +0 -112
  795. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +0 -46
  796. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf +0 -0
  797. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +0 -112
  798. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +0 -46
  799. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf +0 -0
  800. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +0 -112
  801. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +0 -46
  802. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf +0 -0
  803. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +0 -112
  804. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +0 -46
  805. package/cpp/llama.cpp/models/ggml-vocab-gpt-neox.gguf +0 -0
  806. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf +0 -0
  807. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +0 -112
  808. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +0 -46
  809. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf +0 -0
  810. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +0 -112
  811. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +0 -46
  812. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf +0 -0
  813. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +0 -112
  814. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +0 -46
  815. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  816. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf +0 -0
  817. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +0 -112
  818. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +0 -46
  819. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf +0 -0
  820. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +0 -112
  821. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +0 -46
  822. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf +0 -0
  823. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +0 -112
  824. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +0 -46
  825. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf +0 -0
  826. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +0 -112
  827. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +0 -46
  828. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +0 -171
  829. package/cpp/llama.cpp/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja +0 -202
  830. package/cpp/llama.cpp/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja +0 -156
  831. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +0 -124
  832. package/cpp/llama.cpp/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja +0 -152
  833. package/cpp/llama.cpp/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja +0 -152
  834. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +0 -62
  835. package/cpp/llama.cpp/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja +0 -54
  836. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +0 -85
  837. package/cpp/llama.cpp/models/templates/README.md +0 -25
  838. package/cpp/llama.cpp/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja +0 -1
  839. package/cpp/llama.cpp/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja +0 -1
  840. package/cpp/llama.cpp/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja +0 -57
  841. package/cpp/llama.cpp/models/templates/google-gemma-2-2b-it.jinja +0 -4
  842. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +0 -59
  843. package/cpp/llama.cpp/models/templates/llama-cpp-deepseek-r1.jinja +0 -76
  844. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +0 -34
  845. package/cpp/llama.cpp/models/templates/meetkai-functionary-medium-v3.1.jinja +0 -58
  846. package/cpp/llama.cpp/models/templates/meetkai-functionary-medium-v3.2.jinja +0 -287
  847. package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja +0 -109
  848. package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja +0 -93
  849. package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja +0 -109
  850. package/cpp/llama.cpp/models/templates/microsoft-Phi-3.5-mini-instruct.jinja +0 -8
  851. package/cpp/llama.cpp/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja +0 -87
  852. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +0 -43
  853. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +0 -331
  854. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +0 -105
  855. package/cpp/llama.cpp/prompts/LLM-questions.txt +0 -49
  856. package/cpp/llama.cpp/prompts/alpaca.txt +0 -1
  857. package/cpp/llama.cpp/prompts/assistant.txt +0 -31
  858. package/cpp/llama.cpp/prompts/chat-with-baichuan.txt +0 -4
  859. package/cpp/llama.cpp/prompts/chat-with-bob.txt +0 -7
  860. package/cpp/llama.cpp/prompts/chat-with-qwen.txt +0 -1
  861. package/cpp/llama.cpp/prompts/chat-with-vicuna-v0.txt +0 -7
  862. package/cpp/llama.cpp/prompts/chat-with-vicuna-v1.txt +0 -7
  863. package/cpp/llama.cpp/prompts/chat.txt +0 -28
  864. package/cpp/llama.cpp/prompts/dan-modified.txt +0 -1
  865. package/cpp/llama.cpp/prompts/dan.txt +0 -1
  866. package/cpp/llama.cpp/prompts/mnemonics.txt +0 -93
  867. package/cpp/llama.cpp/prompts/parallel-questions.txt +0 -43
  868. package/cpp/llama.cpp/prompts/reason-act.txt +0 -18
  869. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  870. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  871. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5524
  872. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4247
  873. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-alloc.h +0 -76
  874. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +0 -354
  875. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-blas.h +0 -25
  876. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +0 -145
  877. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-metal.h +0 -66
  878. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +0 -256
  879. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +0 -2492
  880. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/gguf.h +0 -202
  881. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -1391
  882. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Modules/module.modulemap +0 -17
  883. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Resources/Info.plist +0 -32
  884. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-alloc.h +0 -76
  885. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +0 -354
  886. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-blas.h +0 -25
  887. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +0 -145
  888. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-metal.h +0 -66
  889. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +0 -256
  890. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +0 -2492
  891. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/gguf.h +0 -202
  892. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -1391
  893. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Modules/module.modulemap +0 -17
  894. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Resources/Info.plist +0 -32
  895. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  896. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-alloc.h +0 -76
  897. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +0 -354
  898. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-blas.h +0 -25
  899. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +0 -145
  900. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-metal.h +0 -66
  901. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +0 -256
  902. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +0 -2492
  903. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/gguf.h +0 -202
  904. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -1391
  905. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Modules/module.modulemap +0 -17
  906. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Resources/Info.plist +0 -32
  907. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  908. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  909. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  910. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  911. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5561
  912. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-alloc.h +0 -76
  913. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +0 -354
  914. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-blas.h +0 -25
  915. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +0 -145
  916. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-metal.h +0 -66
  917. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +0 -256
  918. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +0 -2492
  919. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/gguf.h +0 -202
  920. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -1391
  921. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Info.plist +0 -35
  922. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Modules/module.modulemap +0 -17
  923. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  924. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  925. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  926. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5524
  927. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4246
  928. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-alloc.h +0 -76
  929. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +0 -354
  930. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-blas.h +0 -25
  931. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +0 -145
  932. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +0 -66
  933. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +0 -256
  934. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +0 -2492
  935. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/gguf.h +0 -202
  936. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -1391
  937. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Info.plist +0 -35
  938. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Modules/module.modulemap +0 -17
  939. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  940. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  941. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  942. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5558
  943. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-alloc.h +0 -76
  944. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +0 -354
  945. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-blas.h +0 -25
  946. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +0 -145
  947. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-metal.h +0 -66
  948. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +0 -256
  949. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +0 -2492
  950. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/gguf.h +0 -202
  951. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -1391
  952. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Info.plist +0 -32
  953. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Modules/module.modulemap +0 -17
  954. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  955. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  956. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  957. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5520
  958. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4243
  959. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-alloc.h +0 -76
  960. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +0 -354
  961. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-blas.h +0 -25
  962. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +0 -145
  963. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +0 -66
  964. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +0 -256
  965. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +0 -2492
  966. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/gguf.h +0 -202
  967. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -1391
  968. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Info.plist +0 -32
  969. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Modules/module.modulemap +0 -17
  970. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  971. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  972. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  973. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  974. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  975. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +0 -0
  976. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +0 -0
  977. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  978. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  979. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -0,0 +1,1233 @@
1
+ #include "common.cuh"
2
+ #include "fattn-common.cuh"
3
+ #include "fattn-wmma-f16.cuh"
4
+
5
+ // nbatch_fa == number of KQ rows to process per iteration
6
+ // nbatch_K == number of K columns to load in parallel for KQ calculation
7
+
8
+ // TODO optimize kernel parameters for FP16 NVIDIA (P100)
9
+ // TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112
10
+
11
+ // The ROCm compiler cannot handle templating in __launch_bounds__.
12
+ // As a workaround, define a macro to package the kernel parameters as uint32_t (one table entry per invocation):
13
+ #define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \
14
+ if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { /* DKQ, DV and ncols are taken from the enclosing function */ \
15
+ static_assert((nthreads) <= 512, "bad nthreads"); /* field starts at bit 0, next field starts at bit 10 */ \
16
+ static_assert((occupancy) <= 8, "bad occupancy"); /* field starts at bit 10, next field starts at bit 14 */ \
17
+ static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); /* field starts at bit 14, next field starts at bit 23 */ \
18
+ static_assert((nbatch_K) <= 256, "bad nbatch_K"); /* field starts at bit 23 */ \
19
+ return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); /* returns from the enclosing function on a match */ \
20
+ } \
21
+
22
+ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) { // packed config lookup keyed on (DKQ, DV, ncols)
23
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40) // columns: DKQ, DV, ncols, nthreads, occupancy, nbatch_fa, nbatch_K
24
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40)
25
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40)
26
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40)
27
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40)
28
+
29
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64)
30
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64)
31
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64)
32
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64)
33
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
34
+
35
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72)
36
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72)
37
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72)
38
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72)
39
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72)
40
+
41
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40)
42
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40)
43
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40)
44
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40)
45
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40)
46
+
47
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48)
48
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48)
49
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48)
50
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48)
51
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48)
52
+
53
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56)
54
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56)
55
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56)
56
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56)
57
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56)
58
+
59
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64)
60
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64)
61
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64)
62
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64)
63
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
64
+
65
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64)
66
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64)
67
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64)
68
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
69
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
70
+
71
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
72
+
73
+ return 0; // no table entry matched this (DKQ, DV, ncols) combination
74
+ }
75
+
76
+ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) { // packed config lookup keyed on (DKQ, DV, ncols)
77
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) // columns: DKQ, DV, ncols, nthreads, occupancy, nbatch_fa, nbatch_K
78
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
79
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
80
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
81
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
82
+
83
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64)
84
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64)
85
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64)
86
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64)
87
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
88
+
89
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
90
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
91
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
92
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
93
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
94
+
95
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
96
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
97
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
98
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
99
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
100
+
101
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
102
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
103
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
104
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
105
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
106
+
107
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
108
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
109
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
110
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
111
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
112
+
113
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64)
114
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128)
115
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128)
116
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128)
117
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
118
+
119
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64)
120
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64)
121
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256)
122
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
123
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
124
+
125
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
126
+
127
+ return 0; // no table entry matched this (DKQ, DV, ncols) combination
128
+ }
129
+
130
+ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { // packed config lookup keyed on (DKQ, DV, ncols)
131
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) // columns: DKQ, DV, ncols, nthreads, occupancy, nbatch_fa, nbatch_K
132
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
133
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
134
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
135
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
136
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
137
+
138
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64)
139
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64)
140
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64)
141
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64)
142
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
143
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
144
+
145
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
146
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
147
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
148
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
149
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
150
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
151
+
152
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
153
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
154
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
155
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
156
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
157
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
158
+
159
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
160
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
161
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
162
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
163
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
164
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
165
+
166
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
167
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
168
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
169
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
170
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
171
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
172
+
173
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64)
174
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128)
175
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128)
176
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128)
177
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
178
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)
179
+
180
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
181
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
182
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
183
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
184
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
185
+
186
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
187
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
188
+
189
+ return 0; // no table entry matched this (DKQ, DV, ncols) combination
190
+ }
191
+
192
+ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { // packed config lookup keyed on (DKQ, DV, ncols)
193
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) // columns: DKQ, DV, ncols, nthreads, occupancy, nbatch_fa, nbatch_K
194
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
195
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
196
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
197
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
198
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
199
+
200
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64)
201
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64)
202
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64)
203
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64)
204
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
205
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
206
+
207
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
208
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
209
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
210
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
211
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
212
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
213
+
214
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
215
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
216
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
217
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
218
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
219
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
220
+
221
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
222
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
223
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
224
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
225
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
226
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
227
+
228
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
229
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
230
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
231
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
232
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
233
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
234
+
235
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64)
236
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64)
237
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64)
238
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
239
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
240
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)
241
+
242
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
243
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
244
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
245
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
246
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
247
+
248
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
249
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
250
+
251
+ return 0; // no table entry matched this (DKQ, DV, ncols) combination
252
+ }
253
+
254
+ static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
255
+ // cc classified as AMD: RDNA gets its own table, every other AMD cc shares the generic AMD table.
256
+ if (GGML_CUDA_CC_IS_AMD(cc)) {
257
+ return GGML_CUDA_CC_IS_RDNA(cc)
258
+ ? ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols)
259
+ : ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
260
+ }
261
+ // Any other cc: choose between the FP16 and FP32 tables via fast_fp16_available(cc).
262
+ return fast_fp16_available(cc)
263
+ ? ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols)
264
+ : ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
265
+ }
266
+
267
+ static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) { // 3-argument overload: the table is chosen by preprocessor macros, not by a cc argument
268
+ #ifdef GGML_USE_HIP // HIP build: use one of the two AMD tables
269
+ #ifdef RDNA
270
+ return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
271
+ #else // RDNA not defined
272
+ return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
273
+ #endif // RDNA
274
+ #else // GGML_USE_HIP not defined: use one of the two NVIDIA tables
275
+ #ifdef FAST_FP16_AVAILABLE
276
+ return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
277
+ #else // FAST_FP16_AVAILABLE not defined
278
+ return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
279
+ #endif // FAST_FP16_AVAILABLE
280
+ #endif // GGML_USE_HIP
281
+ }
282
+
283
+ static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { // Host-side getter for the nthreads field of the packed config.
284
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1); // nthreads occupies bits [0, 10).
285
+ }
286
+
287
+ static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) { // Device-side (constexpr) getter for the nthreads field of the packed config.
288
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1); // nthreads occupies bits [0, 10).
289
+ }
290
+
291
+ static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { // Host-side getter for the occupancy field of the packed config.
292
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1); // occupancy occupies bits [10, 14).
293
+ }
294
+
295
+ static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) { // Device-side (constexpr) getter for the occupancy field of the packed config.
296
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1); // occupancy occupies bits [10, 14).
297
+ }
298
+
299
+ static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { // Host-side getter for the nbatch_fa field of the packed config.
300
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1); // nbatch_fa occupies bits [14, 23).
301
+ }
302
+
303
+ static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { // Device-side (constexpr) getter for the nbatch_fa field of the packed config.
304
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1); // nbatch_fa occupies bits [14, 23).
305
+ }
306
+
307
+ static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) { // Host-side getter for the nbatch_K field of the packed config.
308
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1); // nbatch_K occupies bits [23, 32).
309
+ }
310
+
311
+ static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) { // Device-side (constexpr) getter for the nbatch_K field of the packed config.
312
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1); // nbatch_K occupies bits [23, 32).
313
+ }
314
+
315
+ // Copies an I x J tile of half values (J/2 half2 per row) from KV into tile_KV; with oob_check, rows i >= i_sup are zero-filled instead of read. TODO: deduplicate with mma-f16
316
+ template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
317
+ static __device__ __forceinline__ void flash_attn_tile_load_tile(
318
+ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
319
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
320
+ constexpr int cpy_ne = cpy_nb / 4; // Number of 4-byte elements (here: half2) moved per copy.
321
+ // Pass n of the copy: stride_j = warp_size >> n threads share one row, so a warp covers stride_i = warp_size/stride_j rows per step.
322
+ auto load = [&] __device__ (const int n) {
323
+ const int stride_j = warp_size >> n;
324
+
325
+ if (stride_j == 0) {
326
+ return;
327
+ }
328
+ // Columns for this pass (in units of cpy_ne half2): the first pass starts at 0, later passes start where the previous, wider pass stopped.
329
+ const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);
330
+ const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);
331
+ const int stride_i = warp_size / stride_j;
332
+
333
+ if (j0_start == j0_stop) { // Nothing left for this pass.
334
+ return;
335
+ }
336
+
337
+ #pragma unroll
338
+ for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
339
+ const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);
340
+
341
+ if (i0 + nwarps*stride_i <= I || i < I) { // The per-thread row check only matters for a final, partial step.
342
+ #pragma unroll
343
+ for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
344
+ const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
345
+
346
+ const half2 zero[cpy_ne] = {{0.0f, 0.0f}};
347
+ ggml_cuda_memcpy_1<cpy_nb>(
348
+ tile_KV + i*(J/2 + J_padding) + j, // Destination rows are J/2 + J_padding half2 apart.
349
+ !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
350
+ }
351
+ }
352
+ }
353
+ };
354
+ // 1: max 64*16=512 bytes, 512 half
355
+ // 2: max 32*16=512 bytes, 256 half
356
+ // 3: max 16*16=256 bytes, 128 half
357
+ // 4: max 8*16=128 bytes, 64 half
358
+ // 5: max 4*16= 64 bytes, 32 half
359
+ // 6: max 2*16= 32 bytes, 16 half
360
+ // 7: max 1*16= 16 bytes, 8 half
361
+ static_assert(J % 8 == 0, "bad J");
362
+ static_assert((J/2) % cpy_ne == 0, "bad J");
363
+ ggml_cuda_unroll<7>{}(load); // Presumably invokes load(n) once per pass listed above; ggml_cuda_unroll is defined elsewhere - TODO confirm.
364
+ }
365
+
366
+ template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
367
+ static __device__ __forceinline__ void flash_attn_tile_load_tile( // Overload with a float destination: half2 values read from KV are converted to float before being stored in tile_KV.
368
+ const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
369
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
370
+ constexpr int cpy_ne = cpy_nb / 4; // Number of 4-byte elements (here: float) stored per copy.
371
+ // Pass n of the copy: stride_j = warp_size >> n threads share one row, so a warp covers stride_i = warp_size/stride_j rows per step.
372
+ auto load = [&] __device__ (const int n) {
373
+ const int stride_j = warp_size >> n;
374
+
375
+ if (stride_j == 0) {
376
+ return;
377
+ }
378
+ // Columns for this pass (in units of cpy_ne floats): the first pass starts at 0, later passes start where the previous, wider pass stopped.
379
+ const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);
380
+ const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);
381
+ const int stride_i = warp_size / stride_j;
382
+
383
+ if (j0_start == j0_stop) { // Nothing left for this pass.
384
+ return;
385
+ }
386
+
387
+ #pragma unroll
388
+ for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
389
+ const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);
390
+
391
+ if (i0 + nwarps*stride_i <= I || i < I) { // The per-thread row check only matters for a final, partial step.
392
+ #pragma unroll
393
+ for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
394
+ const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); // j indexes half2 in the source row.
395
+
396
+ const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
397
+ half2 tmp_h2[cpy_ne/2];
398
+ ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
399
+ tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); // With oob_check, rows i >= i_sup read zeros instead of KV.
400
+ // Widen each half2 to float2 before the store.
401
+ float2 tmp_f2[cpy_ne/2];
402
+ #pragma unroll
403
+ for (int l = 0; l < cpy_ne/2; ++l) {
404
+ tmp_f2[l] = __half22float2(tmp_h2[l]);
405
+ }
406
+ ggml_cuda_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); // Destination rows are J + J_padding floats apart.
407
+ }
408
+ }
409
+ }
410
+ };
411
+ // 1: max 32*16=512 bytes, 128 float
412
+ // 2: max 16*16=256 bytes, 64 float
413
+ // 3: max 8*16=128 bytes, 32 float
414
+ // 4: max 4*16= 64 bytes, 16 float
415
+ // 5: max 2*16= 32 bytes, 8 float
416
+ static_assert(J % 8 == 0, "bad J");
417
+ static_assert(J % cpy_ne == 0, "bad J");
418
+ ggml_cuda_unroll<5>{}(load); // Presumably invokes load(n) once per pass listed above; ggml_cuda_unroll is defined elsewhere - TODO confirm.
419
+ }
420
+
421
+ // Function that performs a single iteration of the KQ matrix multiplication: loads one nbatch_fa x nbatch_K tile of K and accumulates its products with Q into KQ_acc.
422
+ template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int nbatch_fa, int nbatch_K,
423
+ bool use_logit_softcap, bool oob_check, typename T_vec_dot>
424
+ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
425
+ T_vec_dot * const Q_tmp,
426
+ const half2 * const __restrict__ K_h2,
427
+ T_vec_dot * const KV_tmp,
428
+ const int stride_K2,
429
+ const int k_VKQ_0,
430
+ const int k_VKQ_sup,
431
+ const int k_KQ_0,
432
+ float * KQ_acc) {
433
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
434
+ constexpr int cpy_ne = cpy_nb / 4;
435
+
436
+ constexpr int ncols = ncols1*ncols2;
437
+ constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
438
+ constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
439
+ // Stage the K tile in KV_tmp; rows are padded by cpy_ne elements (J_padding == cpy_ne).
440
+ flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
441
+ (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
442
+ __syncthreads();
443
+ // Walk the k dimension of the tile cpy_ne array elements at a time (half2 with FAST_FP16_AVAILABLE, float otherwise).
444
+ #ifdef FAST_FP16_AVAILABLE
445
+ static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
446
+ #pragma unroll
447
+ for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
448
+ half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
449
+ half2 Q_k[cpw][cpy_ne];
450
+ #else
451
+ static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
452
+ #pragma unroll
453
+ for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
454
+ float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
455
+ float Q_k[cpw][cpy_ne];
456
+ #endif // FAST_FP16_AVAILABLE
457
+ // Fetch this thread's K rows from the staged tile.
458
+ #pragma unroll
459
+ for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
460
+ const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;
461
+
462
+ #ifdef FAST_FP16_AVAILABLE
463
+ ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);
464
+ #else
465
+ ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]);
466
+ #endif // FAST_FP16_AVAILABLE
467
+ }
468
+ #pragma unroll
469
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
470
+ const int jc = jc0 + (threadIdx.y / np)*cpw;
471
+
472
+ #ifdef FAST_FP16_AVAILABLE
473
+ ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);
474
+ #else
475
+ ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]);
476
+ #endif // FAST_FP16_AVAILABLE
477
+ }
478
+ // Multiply-accumulate every (K row, Q column) pair held by this thread into KQ_acc.
479
+ #pragma unroll
480
+ for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
481
+ #pragma unroll
482
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
483
+ #pragma unroll
484
+ for (int k = 0; k < cpy_ne; ++k) {
485
+ ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);
486
+ }
487
+ }
488
+ }
489
+ }
490
+
491
+ if (k_KQ_0 + nbatch_K < DKQ) {
492
+ __syncthreads(); // Sync not needed on last iteration.
493
+ }
494
+ }
495
+
496
+ // Function that performs a single iteration of the main loop over up to nbatch_fa tokens.
497
+ template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int DV, int nbatch_fa, int nbatch_K,
498
+ bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc>
499
+ static __device__ __forceinline__ void flash_attn_tile_iter(
500
+ T_vec_dot * const Q_tmp,
501
+ const half2 * const __restrict__ K_h2,
502
+ const half2 * const __restrict__ V_h2,
503
+ const half * const __restrict__ mask,
504
+ const float logit_softcap,
505
+ const float slope,
506
+ T_KQ * const KQ,
507
+ T_vec_dot * const KV_tmp,
508
+ const int stride_K2,
509
+ const int stride_V2,
510
+ const int stride_mask,
511
+ float * const KQ_max,
512
+ float * const KQ_sum,
513
+ T_acc * const VKQ,
514
+ const int k_VKQ_0,
515
+ const int k_VKQ_max) {
516
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
517
+ constexpr int cpy_ne = cpy_nb / 4;
518
+
519
+ constexpr int ncols = ncols1*ncols2;
520
+ constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
521
+ constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
522
+
523
+ constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.
524
+
525
+ // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory.
526
+ // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][DVp][KQ_cs].
527
+ #ifdef FAST_FP16_AVAILABLE
528
+ constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
529
+ #else
530
+ constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
531
+ #endif // FAST_FP16_AVAILABLE
532
+ static_assert(cpw % KQ_cs == 0, "bad KQ_cs");
533
+ const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data
534
+ // Start the new maxima from the running maxima passed in by the caller.
535
+ float KQ_max_new[cpw];
536
+ #pragma unroll
537
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
538
+ KQ_max_new[jc0] = KQ_max[jc0];
539
+ }
540
+
541
+ float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.
542
+
543
+ // KQ = K @ Q matrix multiplication:
544
+ constexpr int nbatch_K_last = DKQ % nbatch_K;
545
+ #pragma unroll
546
+ for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
547
+ flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(
548
+ Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
549
+ }
550
+ if (nbatch_K_last > 0) { // Remainder batch when DKQ is not a multiple of nbatch_K.
551
+ constexpr int k_KQ_0 = DKQ - nbatch_K_last;
552
+ flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(
553
+ Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
554
+ }
555
+
556
+ // Apply logit softcap + mask, update KQ_max:
557
+ #pragma unroll
558
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
559
+ const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2;
560
+
561
+ #pragma unroll
562
+ for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
563
+ const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;
564
+
565
+ if (use_logit_softcap) {
566
+ KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
567
+ }
568
+
569
+ if (!oob_check || i_KQ < k_VKQ_sup) { // Out-of-range rows get no mask term and do not affect the maximum.
570
+ KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
571
+ slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
572
+
573
+ KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
574
+ }
575
+ }
576
+
577
+ KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
578
+ }
579
+
580
+ if constexpr (np == 1) {
581
+ __syncthreads();
582
+ } else { // np > 1: the warps working on the same Q column combine their maxima through shared memory.
583
+ static_assert(cpw == 1, "bad cpw");
584
+ __shared__ float KQ_max_new_shared[nwarps];
585
+ if (threadIdx.x == 0) {
586
+ KQ_max_new_shared[threadIdx.y] = KQ_max_new[0];
587
+ }
588
+ __syncthreads();
589
+ KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np];
590
+ KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
591
+ }
592
+
593
+ // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
594
+ #pragma unroll
595
+ for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
596
+ #ifdef FAST_FP16_AVAILABLE
597
+ half tmp[nbatch_fa/(np*warp_size)][KQ_cs];
598
+ #else
599
+ float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
600
+ #endif // FAST_FP16_AVAILABLE
601
+
602
+ #pragma unroll
603
+ for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {
604
+ const int jc = jc0 + jc1;
605
+
606
+ const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]); // Factor that rescales previously accumulated values to the new maximum.
607
+ KQ_max[jc] = KQ_max_new[jc];
608
+
609
+ float KQ_sum_add = 0.0f;
610
+ #pragma unroll
611
+ for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
612
+ const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
613
+ expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
614
+ KQ_sum_add += val;
615
+ tmp[i0/(np*warp_size)][jc1] = val;
616
+ }
617
+ KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
618
+
619
+ #ifdef FAST_FP16_AVAILABLE
620
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
621
+ #pragma unroll
622
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
623
+ VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
624
+ }
625
+ #else
626
+ #pragma unroll
627
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
628
+ VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
629
+ VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
630
+ }
631
+ #endif // FAST_FP16_AVAILABLE
632
+ }
633
+ // Store this chunk of softmax values into the KQ buffer, KQ_cs values per row i.
634
+ #pragma unroll
635
+ for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
636
+ const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x;
637
+
638
+ ggml_cuda_memcpy_1<sizeof(tmp[0])>(
639
+ KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs,
640
+ tmp[i0/(np*warp_size)]);
641
+ }
642
+ }
643
+
644
+ // VKQ = V @ KQ matrix multiplication:
645
+ static_assert(DV <= DKQ, "bad DV");
646
+ static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K");
647
+ constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.
648
+ static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V");
649
+ static_assert(nbatch_V % np == 0, "bad nbatch_V");
650
+ #pragma unroll
651
+ for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
652
+ flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
653
+ (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
654
+ __syncthreads();
655
+
656
+ #ifdef FAST_FP16_AVAILABLE
657
+ #pragma unroll
658
+ for (int k1 = 0; k1 < nbatch_V; k1 += np) {
659
+ half2 V_k[(DVp/2)/warp_size];
660
+ half2 KQ_k[cpw];
661
+
662
+ constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
663
+ #pragma unroll
664
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
665
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]);
666
+ }
667
+ #pragma unroll
668
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
669
+ const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
670
+
671
+ half tmp[KQ_cs];
672
+ ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(
673
+ &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
674
+ #pragma unroll
675
+ for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {
676
+ KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]);
677
+ }
678
+ }
679
+
680
+ #pragma unroll
681
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
682
+ #pragma unroll
683
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
684
+ VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0];
685
+ }
686
+ }
687
+ }
688
+ #else
689
+ #pragma unroll
690
+ for (int k1 = 0; k1 < nbatch_V; k1 += np) {
691
+ float2 V_k[(DVp/2)/warp_size];
692
+ float KQ_k[cpw];
693
+
694
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
695
+ #pragma unroll
696
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
697
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]);
698
+ }
699
+ #pragma unroll
700
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
701
+ const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
702
+
703
+ ggml_cuda_memcpy_1<KQ_cs*sizeof(float)>(
704
+ &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
705
+ }
706
+
707
+ #pragma unroll
708
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
709
+ #pragma unroll
710
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
711
+ VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0];
712
+ VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0];
713
+ }
714
+ }
715
+ }
716
+ #endif // FAST_FP16_AVAILABLE
717
+
718
+ __syncthreads();
719
+ }
720
+ }
721
+
722
+ template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // DKQ == head size of Q/K, DV == head size of V
723
+ __launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2))
724
+ static __global__ void flash_attn_tile(
725
+ const char * __restrict__ Q,
726
+ const char * __restrict__ K,
727
+ const char * __restrict__ V,
728
+ const char * __restrict__ mask,
729
+ const char * __restrict__ sinks,
730
+ const int * __restrict__ KV_max,
731
+ float * __restrict__ dst,
732
+ float2 * __restrict__ dst_meta,
733
+ const float scale,
734
+ const float max_bias,
735
+ const float m0,
736
+ const float m1,
737
+ const uint32_t n_head_log2,
738
+ const float logit_softcap,
739
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
740
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
741
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
742
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
743
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
744
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
745
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
746
+ #ifdef FLASH_ATTN_AVAILABLE
747
+
748
+ // Skip unused kernel variants for faster compilation:
749
+
750
+ if (
751
+ #ifdef GGML_USE_WMMA_FATTN
752
+ (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
753
+ #endif // GGML_USE_WMMA_FATTN
754
+ (use_logit_softcap && !(DV == 128 || DV == 256))
755
+ ) {
756
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
757
+ max_bias, m0, m1, n_head_log2, logit_softcap,
758
+ ne00, ne01, ne02, ne03,
759
+ nb01, nb02, nb03,
760
+ ne10, ne11, ne12, ne13,
761
+ nb11, nb12, nb13,
762
+ nb21, nb22, nb23,
763
+ ne31, ne32, ne33,
764
+ nb31, nb32, nb33);
765
+ NO_DEVICE_CODE;
766
+ return;
767
+ }
768
+
769
+ static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");
770
+
771
+ constexpr int ncols = ncols1*ncols2;
772
+ constexpr int warp_size = 32;
773
+ constexpr int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;
774
+ constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);
775
+ constexpr int nbatch_K = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);
776
+
777
+ // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
778
+
779
+ const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on.
780
+
781
+ const int sequence = blockIdx.z / (ne02/ncols2);
782
+ const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2)
783
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
784
+ const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0);
785
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
786
+ const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
787
+
788
+ const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr;
789
+ // Byte strides converted to element strides for the typed pointers above.
790
+ const int stride_K2 = nb11 / sizeof(half2);
791
+ const int stride_V2 = nb21 / sizeof(half2);
792
+ const int stride_mask = nb31 / sizeof(half);
793
+
794
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
795
+
796
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
797
+ constexpr int cpy_ne = cpy_nb / 4;
798
+
799
+ constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.
800
+ constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column.
801
+ static_assert(cpw == 1 || np == 1, "bad cpw / np");
802
+ static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0");
803
+
804
+ constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.
805
+ constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.
806
+
807
+ // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel.
808
+ // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.
809
+ // KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).
810
+ // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.
811
+ // VKQ == Accumulators in registers for the final VKQ result.
812
+ #ifdef FAST_FP16_AVAILABLE
813
+ __shared__ half2 Q_tmp[ncols * DKQ/2];
814
+ __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
815
+ __shared__ half KQ[ncols * nbatch_fa];
816
+ half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
817
+ #else
818
+ __shared__ float Q_tmp[ncols * DKQ];
819
+ __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
820
+ __shared__ float KQ[ncols * nbatch_fa];
821
+ float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
822
+ #endif // FAST_FP16_AVAILABLE
823
+ // Per-thread softmax state: running maximum (initialised to -FLT_MAX/2) and running sum for each Q column of this warp.
824
+ float KQ_max[cpw];
825
+ #pragma unroll
826
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
827
+ KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
828
+ }
829
+ float KQ_sum[cpw] = {0.0f};
830
+
831
+ // Load Q data, convert to FP16 if fast:
832
+ #pragma unroll
833
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
834
+ const int jc = jc0 + (threadIdx.y / np)*cpw;
835
+
836
+ const int j = jc / ncols2;
837
+ const int c = jc % ncols2;
838
+
839
+ constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size;
840
+
841
+ #pragma unroll
842
+ for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
843
+ if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
844
+ float tmp_f[cpy_ne_D] = {0.0f};
845
+ if (ncols1 == 1 || col_Q_0 + j < ne01) { // Columns at or beyond ne01 are left as zeros.
846
+ ggml_cuda_memcpy_1<sizeof(tmp_f)>
847
+ (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float))
848
+ + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
849
+ }
850
+ // Apply the kernel's scale parameter to Q once, while loading.
851
+ #pragma unroll
852
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
853
+ tmp_f[i1] *= scale;
854
+ }
855
+
856
+ #ifdef FAST_FP16_AVAILABLE
857
+ half2 tmp_h2[cpy_ne_D/2];
858
+ #pragma unroll
859
+ for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
860
+ tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
861
+ }
862
+ ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
863
+ &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)],
864
+ tmp_h2);
865
+ #else
866
+ ggml_cuda_memcpy_1<sizeof(tmp_f)>(
867
+ &Q_tmp[jc* DKQ + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x* cpy_ne_D],
868
+ tmp_f);
869
+ #endif // FAST_FP16_AVAILABLE
870
+ }
871
+ }
872
+ }
873
+
874
+ __syncthreads();
875
+
876
+ // Main loop over KV cache:
877
+ const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
878
+ if (ncols2 == 1) {
879
+ // Branch with out-of-bounds checks.
880
+ int k_VKQ_0 = blockIdx.y*nbatch_fa;
881
+ while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
882
+ constexpr bool oob_check = false;
883
+ flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
884
+ (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
885
+ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
886
+ k_VKQ_0 += gridDim.y*nbatch_fa;
887
+ }
888
+ if (k_VKQ_0 < k_VKQ_max) { // Final batch for this block: run it with the bounds checks enabled.
889
+ constexpr bool oob_check = true;
890
+ flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
891
+ (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
892
+ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
893
+ }
894
+ } else {
895
+ // Branch without out-of-bounds checks.
896
+ for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
897
+ constexpr bool oob_check = false;
898
+ flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
899
+ (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp,
900
+ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max);
901
+ }
902
+ }
903
+ // Sum the per-thread KQ_sum values across each warp.
904
+ #pragma unroll
905
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
906
+ KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]);
907
+ }
908
+
909
+ if constexpr (np > 1) { // Combine the partial results of the np warps that worked on the same Q column.
910
+ static_assert(cpw == 1, "bad cpw");
911
+ static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small");
912
+ // KV_tmp and Q_tmp are reused here as scratch space for the combination.
913
+ #ifdef FAST_FP16_AVAILABLE
914
+ half2 * VKQ_combine = (half2 *) KV_tmp;
915
+ #else
916
+ float * VKQ_combine = (float *) KV_tmp;
917
+ #endif // FAST_FP16_AVAILABLE
918
+ float * KQ_sum_combine = (float *) Q_tmp;
919
+
920
+ if (threadIdx.y % np != 0) { // Warps other than the first of each group store their partial VKQ / KQ_sum and exit.
921
+ #ifdef FAST_FP16_AVAILABLE
922
+ constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
923
+ #pragma unroll
924
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
925
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]);
926
+ }
927
+ #else
928
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
929
+ #pragma unroll
930
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
931
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(
932
+ &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);
933
+ }
934
+ #endif // FAST_FP16_AVAILABLE
935
+
936
+ if (threadIdx.x == 0) {
937
+ KQ_sum_combine[threadIdx.y] = KQ_sum[0];
938
+ }
939
+
940
+ return;
941
+ }
942
+
943
+ __syncthreads();
944
+ // The first warp of each group adds the stored partial results of the other np-1 warps to its own.
945
+ #pragma unroll
946
+ for (int ip = 1; ip < np; ++ip) {
947
+ #ifdef FAST_FP16_AVAILABLE
948
+ constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
949
+ #pragma unroll
950
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
951
+ half2 tmp[cpy_ne_D];
952
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
953
+ #pragma unroll
954
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
955
+ VKQ[i0/warp_size + i1] += tmp[i1];
956
+ }
957
+ }
958
+ #else
959
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
960
+ #pragma unroll
961
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
962
+ float tmp[cpy_ne_D];
963
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
964
+ #pragma unroll
965
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
966
+ ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];
967
+ }
968
+ }
969
+ #endif // FAST_FP16_AVAILABLE
970
+
971
+ KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip];
972
+ }
973
+ }
974
+
975
+ // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:
976
+ if (sinks && blockIdx.y == 0) {
977
+ #pragma unroll
978
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
979
+ const int jc = jc0 + (threadIdx.y/np)*cpw;
980
+ const float sink = ((const float *) sinks)[head0 + jc % ncols2];
981
+
982
+ float KQ_max_new_j = fmaxf(KQ_max[jc0], sink);
983
+ const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j);
984
+ KQ_max[jc0] = KQ_max_new_j;
985
+
986
+ const float val = expf(sink - KQ_max[jc0]);
987
+ KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;
988
+
989
+ #ifdef FAST_FP16_AVAILABLE
990
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
991
+ #pragma unroll
992
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
993
+ VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
994
+ }
995
+ #else
996
+ #pragma unroll
997
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
998
+ VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
999
+ VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
1000
+ }
1001
+ #endif // FAST_FP16_AVAILABLE
1002
+ }
1003
+ }
1004
+
1005
+ // Write back results:
1006
+ #pragma unroll
1007
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
1008
+ const int jc = jc0 + (threadIdx.y/np)*cpw;
1009
+
1010
+ const int j = jc / ncols2;
1011
+ const int c = jc % ncols2;
1012
+
1013
+ if (ncols1 > 1 && col_Q_0 + j >= ne01) {
1014
+ return;
1015
+ }
1016
+
1017
+ const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; // NOTE: shadows the kernel parameter `scale`; the division by KQ_sum is applied here only when gridDim.y == 1.
1018
+
1019
+ const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
1020
+
1021
+ #ifdef FAST_FP16_AVAILABLE
1022
+ constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
1023
+ #pragma unroll
1024
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
1025
+ float2 tmp[cpy_ne_D];
1026
+ #pragma unroll
1027
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
1028
+ tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
1029
+ tmp[i1].x *= scale;
1030
+ tmp[i1].y *= scale;
1031
+ }
1032
+ if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
1033
+ ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
1034
+ }
1035
+ }
1036
+ #else
1037
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
1038
+ #pragma unroll
1039
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
1040
+ if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
1041
+ #pragma unroll
1042
+ for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
1043
+ VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
1044
+ VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
1045
+ }
1046
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(
1047
+ &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
1048
+ &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
1049
+ }
1050
+ }
1051
+ #endif // FAST_FP16_AVAILABLE
1052
+
1053
+ if (gridDim.y != 1 && threadIdx.x == 0) { // When gridDim.y != 1 the (KQ_max, KQ_sum) pair is stored in dst_meta instead of being applied above.
1054
+ dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);
1055
+ }
1056
+ }
1057
+ #else
1058
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
1059
+ max_bias, m0, m1, n_head_log2, logit_softcap,
1060
+ ne00, ne01, ne02, ne03,
1061
+ nb01, nb02, nb03,
1062
+ ne10, ne11, ne12, ne13,
1063
+ nb11, nb12, nb13,
1064
+ nb21, nb22, nb23,
1065
+ ne31, ne32, ne33,
1066
+ nb31, nb32, nb33);
1067
+ NO_DEVICE_CODE;
1068
+ #endif // FLASH_ATTN_AVAILABLE
1069
+ }
1070
+
1071
// Host-side dispatcher: picks how many Q columns one block handles (cols_per_block) from Q->ne[1]
// and launches the matching flash_attn_tile instantiation through launch_fattn.
//
// Template parameters:
//   DKQ, DV           - forwarded to the tuning helpers and to the kernel (presumably the K/Q and V head sizes - confirm).
//   ncols2            - forwarded to the kernel; the kernel's third template argument is cols_per_block/ncols2.
//   use_logit_softcap - compile-time flag forwarded to the kernel.
//
// Candidates are tried from largest to smallest: 64 (HIP builds only, DV <= 128), 32, 16,
// 8 (ncols2 <= 8), 4 (ncols2 <= 4), 2 (ncols2 <= 2). A candidate is taken when
// Q->ne[1] > (cols_per_block/2)/ncols2 (integer division); the 2-column case has no runtime check.
// The first match launches the kernel and returns; if nothing matches the function aborts.
+ template <int DKQ, int DV, int ncols2, bool use_logit_softcap>
1072
+ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1073
+ const ggml_tensor * Q = dst->src[0];
1074
+
1075
+ const int id = ggml_cuda_get_device();
1076
+ const int cc = ggml_cuda_info().devices[id].cc;
1077
+ const int warp_size = 32; // hard-coded here for every build; used to convert a thread count into a warp count
1078
+
1079
+ constexpr size_t nbytes_shared = 0; // passed unchanged to launch_fattn (presumably the dynamic shared memory size - confirm)
1080
+
1081
// HIP-only candidate: 64 columns per block, compiled only for DV <= 128.
+ #ifdef GGML_USE_HIP
1082
+ if constexpr (DV <= 128) {
1083
+ if (Q->ne[1] > 32/ncols2) {
1084
+ constexpr int cols_per_block = 64;
1085
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1086
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1087
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
1088
// The meaning of the three boolean arguments is defined by launch_fattn, which is not visible here.
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
1089
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
1090
+ return;
1091
+ }
1092
+ }
1093
+ #endif // GGML_USE_HIP
1094
+
1095
// 32 columns per block. On non-HIP builds this candidate is compiled only for DV <= 256;
// on HIP builds the `if constexpr` is preprocessed away and the braces are a plain scope.
+ #ifndef GGML_USE_HIP
1096
+ if constexpr (DV <= 256)
1097
+ #endif // GGML_USE_HIP
1098
+ {
1099
+ if (Q->ne[1] > 16/ncols2) {
1100
+ constexpr int cols_per_block = 32;
1101
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1102
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1103
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
1104
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
1105
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
1106
+ return;
1107
+ }
1108
+ }
1109
+
1110
// 16 columns per block: the only candidate that is compiled for every template argument combination.
+ if (Q->ne[1] > 8/ncols2) {
1111
+ constexpr int cols_per_block = 16;
1112
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1113
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1114
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
1115
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
1116
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
1117
+ return;
1118
+ }
1119
+
1120
// The remaining candidates are guarded by `if constexpr` so that cols_per_block/ncols2 never becomes 0.
+ if constexpr (ncols2 <= 8) {
1121
+ if (Q->ne[1] > 4/ncols2) {
1122
+ constexpr int cols_per_block = 8;
1123
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1124
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1125
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
1126
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
1127
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
1128
+ return;
1129
+ }
1130
+ }
1131
+
1132
+ if constexpr (ncols2 <= 4) {
1133
+ if (Q->ne[1] > 2/ncols2) {
1134
+ constexpr int cols_per_block = 4;
1135
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1136
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1137
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
1138
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
1139
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
1140
+ return;
1141
+ }
1142
+ }
1143
+
1144
// Smallest candidate: no runtime check on Q->ne[1], so for ncols2 <= 2 this is the catch-all.
+ if constexpr (ncols2 <= 2) {
1145
+ constexpr int cols_per_block = 2;
1146
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
1147
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
1148
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
1149
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
1150
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
1151
+ return;
1152
+ }
1153
+
1154
// Reached only when ncols2 > 2 and none of the runtime checks above matched.
+ GGML_ABORT("fatal error");
1155
+ }
1156
+
1157
// Host-side dispatcher: selects the ncols2 template argument from the ratio Q->ne[2] / K->ne[2]
// (named gqa_ratio below) and forwards to launch_fattn_tile_switch_ncols1.
//
// ncols2 > 1 is only used when use_gqa_opt holds, i.e. a mask tensor is present, max_bias is 0,
// Q->ne[1] does not exceed gqa_limit and K->ne[1] is a multiple of FATTN_KQ_STRIDE.
//   DV == 512 : ncols2 = 16 if use_gqa_opt and gqa_ratio is a multiple of 16; otherwise execution
//               falls through to the abort at the end.
//   DV <= 256 : the largest of 8, 4, 2 that divides gqa_ratio (when use_gqa_opt holds), else 1.
//   other DV  : aborts.
// NOTE(review): callers presumably route DV == 512 here only when the ncols2 = 16 condition is met - confirm.
+ template <int DKQ, int DV, bool use_logit_softcap>
1158
+ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1159
+ const ggml_tensor * KQV = dst;
1160
+ const ggml_tensor * Q = dst->src[0];
1161
+ const ggml_tensor * K = dst->src[1];
1162
+ const ggml_tensor * mask = dst->src[3];
1163
+
1164
+ float max_bias = 0.0f;
1165
// max_bias is stored as the second float (index 1) of the op_params array; memcpy is used to read it.
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
1166
+
1167
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
1168
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
1169
+
1170
+ const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
1171
// On NVIDIA devices with gqa_ratio <= 4 the ncols2 > 1 paths are limited to Q->ne[1] <= 16; otherwise there is no limit.
+ const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
1172
+ const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
1173
+
1174
+ if constexpr (DV == 512) {
1175
+ if (use_gqa_opt && gqa_ratio % 16 == 0) {
1176
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
1177
+ return;
1178
+ }
1179
+ }
1180
+
1181
+ if constexpr (DV <= 256) {
1182
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
1183
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
1184
+ return;
1185
+ }
1186
+
1187
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
1188
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
1189
+ return;
1190
+ }
1191
+
1192
+ if (use_gqa_opt && gqa_ratio % 2 == 0) {
1193
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
1194
+ return;
1195
+ }
1196
+
1197
// Fallback for DV <= 256: ncols2 = 1, used whenever none of the checks above matched.
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
1198
+ return;
1199
+ }
1200
// Reached for DV > 256 unless the DV == 512 branch above launched a kernel.
+ GGML_ABORT("fatal error");
1201
+ }
1202
+
1203
// Entry point for one (DKQ, DV) combination: reads logit_softcap from dst->op_params and turns the
// runtime test "logit_softcap == 0.0f" into the compile-time use_logit_softcap template argument
// of launch_fattn_tile_switch_ncols2.
// Unlike the two launch_* helpers above, this template is not declared static.
+ template <int DKQ, int DV>
1204
+ void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1205
+ const ggml_tensor * KQV = dst;
1206
+
1207
+ float logit_softcap;
1208
// logit_softcap is stored as the third float (index 2) of the op_params array.
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
1209
+
1210
+ if (logit_softcap == 0.0f) {
1211
+ constexpr bool use_logit_softcap = false;
1212
+ launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
1213
+ } else {
1214
+ constexpr bool use_logit_softcap = true;
1215
+ launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
1216
+ }
1217
+ }
1218
+
1219
// Declaration of the non-template entry point; its definition is not part of this chunk.
+ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
1220
+
1221
// DECL_FATTN_TILE_CASE(DKQ, DV) expands to an explicit instantiation of
// ggml_cuda_flash_attn_ext_tile_case<DKQ, DV> without a terminating semicolon.
// The last macro line also ends with a continuation backslash, so the empty line after it belongs to the macro.
+ #define DECL_FATTN_TILE_CASE(DKQ, DV) \
1222
+ template void ggml_cuda_flash_attn_ext_tile_case \
1223
+ <DKQ, DV>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
1224
+
1225
// With the `extern` prefix each line below is an explicit instantiation declaration: it names a
// supported (DKQ, DV) pair without instantiating it here.
// NOTE(review): the matching explicit instantiation definitions presumably live in separate translation units - confirm.
+ extern DECL_FATTN_TILE_CASE( 40, 40);
1226
+ extern DECL_FATTN_TILE_CASE( 64, 64);
1227
+ extern DECL_FATTN_TILE_CASE( 72, 72);
1228
+ extern DECL_FATTN_TILE_CASE( 80, 80);
1229
+ extern DECL_FATTN_TILE_CASE( 96, 96);
1230
+ extern DECL_FATTN_TILE_CASE(112, 112);
1231
+ extern DECL_FATTN_TILE_CASE(128, 128);
1232
+ extern DECL_FATTN_TILE_CASE(256, 256);
1233
+ extern DECL_FATTN_TILE_CASE(576, 512); // the only listed case where DKQ and DV differ