@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -6,7 +6,7 @@
6
6
 
7
7
  template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
- __launch_bounds__(nwarps*WARP_SIZE, 1)
9
+ __launch_bounds__(nwarps*WARP_SIZE, 2)
10
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
11
  static __global__ void flash_attn_tile_ext_f32(
12
12
  const char * __restrict__ Q,
@@ -21,29 +21,13 @@ static __global__ void flash_attn_tile_ext_f32(
21
21
  const float m1,
22
22
  const uint32_t n_head_log2,
23
23
  const float logit_softcap,
24
- const int ne00,
25
- const int ne01,
26
- const int ne02,
27
- const int ne03,
28
- const int ne10,
29
- const int ne11,
30
- const int ne12,
31
- const int ne13,
32
- const int ne31,
33
- const int nb31,
34
- const int nb01,
35
- const int nb02,
36
- const int nb03,
37
- const int nb11,
38
- const int nb12,
39
- const int nb13,
40
- const int nb21,
41
- const int nb22,
42
- const int nb23,
43
- const int ne0,
44
- const int ne1,
45
- const int ne2,
46
- const int ne3) {
24
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
25
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
26
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
27
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
28
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
29
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
30
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
47
31
  #ifdef FLASH_ATTN_AVAILABLE
48
32
 
49
33
  // Skip unused kernel variants for faster compilation:
@@ -53,17 +37,16 @@ static __global__ void flash_attn_tile_ext_f32(
53
37
  #endif // FP16_MMA_AVAILABLE
54
38
  if (use_logit_softcap && !(D == 128 || D == 256)) {
55
39
  GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
56
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
57
- GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
40
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
41
+ GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
58
42
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
59
- GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
60
- GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
61
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
62
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
63
- GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
64
- GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
65
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
66
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
43
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
44
+ GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
45
+ GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
46
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
47
+ GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
48
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
49
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
67
50
  NO_DEVICE_CODE;
68
51
  return;
69
52
  }
@@ -72,15 +55,17 @@ static __global__ void flash_attn_tile_ext_f32(
72
55
 
73
56
  const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
74
57
 
58
+ const int sequence = blockIdx.z / ne02;
59
+ const int head = blockIdx.z - sequence*ne02;
75
60
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
76
- const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
77
- const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
78
- const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
79
- const half * maskh = (const half *) mask + ne11*ic0;
61
+ const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
62
+ const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
63
+ const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
64
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
80
65
 
81
66
  const int stride_KV2 = nb11 / sizeof(half2);
82
67
 
83
- const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
68
+ const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
84
69
 
85
70
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
86
71
 
@@ -129,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
129
114
 
130
115
  #pragma unroll
131
116
  for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
132
- const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
117
+ const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
133
118
  KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
134
119
  KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
135
120
  }
@@ -225,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32(
225
210
  for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
226
211
  const int i = i0 + threadIdx.x;
227
212
 
228
- KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
229
- KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
213
+ const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
214
+ KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
215
+ KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
230
216
  }
231
217
  }
232
218
 
@@ -263,6 +249,8 @@ static __global__ void flash_attn_tile_ext_f32(
263
249
  __syncthreads();
264
250
  }
265
251
 
252
+ float2 * dst2 = (float2 *) dst;
253
+
266
254
  #pragma unroll
267
255
  for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
268
256
  const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -274,37 +262,36 @@ static __global__ void flash_attn_tile_ext_f32(
274
262
  float kqsum_j = kqsum[j_VKQ_0/nwarps];
275
263
  kqsum_j = warp_reduce_sum(kqsum_j);
276
264
 
265
+ const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
266
+
277
267
  #pragma unroll
278
- for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
279
- const int i0 = i00 + 2*threadIdx.x;
268
+ for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
269
+ const int i0 = i00 + threadIdx.x;
280
270
 
281
- float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
271
+ float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
282
272
  if (gridDim.y == 1) {
283
273
  dst_val.x /= kqsum_j;
284
274
  dst_val.y /= kqsum_j;
285
275
  }
286
- const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
287
- dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
288
- dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
276
+ dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
289
277
  }
290
278
 
291
279
  if (gridDim.y != 1 && threadIdx.x == 0) {
292
- dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
280
+ dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
293
281
  }
294
282
  }
295
283
  #else
296
284
  GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
297
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
298
- GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
285
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
286
+ GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
299
287
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
300
- GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
301
- GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
302
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
303
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
304
- GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
305
- GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
306
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
307
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
288
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
289
+ GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
290
+ GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
291
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
292
+ GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
293
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
294
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
308
295
  NO_DEVICE_CODE;
309
296
  #endif // FLASH_ATTN_AVAILABLE
310
297
  }
@@ -18,29 +18,13 @@ static __global__ void flash_attn_vec_ext_f16(
18
18
  const float m1,
19
19
  const uint32_t n_head_log2,
20
20
  const float logit_softcap,
21
- const int ne00,
22
- const int ne01,
23
- const int ne02,
24
- const int ne03,
25
- const int ne10,
26
- const int ne11,
27
- const int ne12,
28
- const int ne13,
29
- const int ne31,
30
- const int nb31,
31
- const int nb01,
32
- const int nb02,
33
- const int nb03,
34
- const int nb11,
35
- const int nb12,
36
- const int nb13,
37
- const int nb21,
38
- const int nb22,
39
- const int nb23,
40
- const int ne0,
41
- const int ne1,
42
- const int ne2,
43
- const int ne3) {
21
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
22
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
23
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
24
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
25
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
26
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
27
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
44
28
  #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
45
29
 
46
30
  // Skip unused kernel variants for faster compilation:
@@ -63,14 +47,16 @@ static __global__ void flash_attn_vec_ext_f16(
63
47
 
64
48
  const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
65
49
 
50
+ const int sequence = blockIdx.z / ne02;
51
+ const int head = blockIdx.z - sequence*ne02;
66
52
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
67
- Q += nb02* blockIdx.z + nb01*ic0;
68
- K += nb12*(blockIdx.z / gqa_ratio);
69
- V += nb22*(blockIdx.z / gqa_ratio);
53
+ Q += nb03*sequence + nb02* head + nb01*ic0;
54
+ K += nb13*sequence + nb12*(head / gqa_ratio);
55
+ V += nb23*sequence + nb22*(head / gqa_ratio);
70
56
 
71
- const half * maskh = (const half *) mask + ne11*ic0;
57
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
72
58
 
73
- const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
59
+ const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
74
60
  const half slopeh = __float2half(slopef);
75
61
 
76
62
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -185,13 +171,16 @@ static __global__ void flash_attn_vec_ext_f16(
185
171
 
186
172
  half2 VKQ[ncols] = {{0.0f, 0.0f}};
187
173
 
174
+ K += blockIdx.y*D * nb11;
175
+ V += blockIdx.y*D * nb21;
176
+ maskh += blockIdx.y*D;
188
177
  for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
189
178
  // Calculate KQ tile and keep track of new maximum KQ values:
190
179
 
191
180
  if (mask) {
192
181
  #pragma unroll
193
182
  for (int j = 0; j < ncols; ++j) {
194
- maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
183
+ maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
195
184
  }
196
185
 
197
186
  __syncthreads();
@@ -238,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
238
227
 
239
228
  #pragma unroll
240
229
  for (int j = 0; j < ncols; ++j) {
241
- half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
230
+ half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
242
231
  sum = warp_reduce_sum((float)sum);
243
232
 
244
233
  if (use_logit_softcap) {
@@ -294,14 +283,18 @@ static __global__ void flash_attn_vec_ext_f16(
294
283
  }
295
284
 
296
285
  half2 V_k;
297
- reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
298
- reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
286
+ reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
287
+ reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
299
288
  #pragma unroll
300
289
  for (int j = 0; j < ncols; ++j) {
301
290
  VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
302
291
  }
303
292
  }
304
293
 
294
+ K += gridDim.y*D * nb11;
295
+ V += gridDim.y*D * nb21;
296
+ maskh += gridDim.y*D;
297
+
305
298
  __syncthreads();
306
299
  }
307
300
 
@@ -328,26 +321,24 @@ static __global__ void flash_attn_vec_ext_f16(
328
321
  if (gridDim.y == 1) {
329
322
  dst_val /= kqsum[j_VKQ];
330
323
  }
331
- const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
332
- dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
324
+ dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
333
325
  }
334
326
 
335
327
  if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
336
- dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
328
+ dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
337
329
  }
338
330
  #else
339
331
  GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
340
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
341
- GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
332
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
333
+ GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
342
334
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
343
- GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
344
- GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
345
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
346
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
347
- GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
348
- GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
349
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
350
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
335
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
336
+ GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
337
+ GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
338
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
339
+ GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
340
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
341
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
351
342
  NO_DEVICE_CODE;
352
343
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
353
344
  }
@@ -18,29 +18,13 @@ static __global__ void flash_attn_vec_ext_f32(
18
18
  const float m1,
19
19
  const uint32_t n_head_log2,
20
20
  const float logit_softcap,
21
- const int ne00,
22
- const int ne01,
23
- const int ne02,
24
- const int ne03,
25
- const int ne10,
26
- const int ne11,
27
- const int ne12,
28
- const int ne13,
29
- const int ne31,
30
- const int nb31,
31
- const int nb01,
32
- const int nb02,
33
- const int nb03,
34
- const int nb11,
35
- const int nb12,
36
- const int nb13,
37
- const int nb21,
38
- const int nb22,
39
- const int nb23,
40
- const int ne0,
41
- const int ne1,
42
- const int ne2,
43
- const int ne3) {
21
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
22
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
23
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
24
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
25
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
26
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
27
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
44
28
  #ifdef FLASH_ATTN_AVAILABLE
45
29
 
46
30
  // Skip unused kernel variants for faster compilation:
@@ -51,12 +35,11 @@ static __global__ void flash_attn_vec_ext_f32(
51
35
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
52
36
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
53
37
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
54
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
55
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
38
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
39
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
56
40
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
57
41
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
58
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
59
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
42
+ GGML_UNUSED(nb23);
60
43
  NO_DEVICE_CODE;
61
44
  return;
62
45
  }
@@ -75,13 +58,16 @@ static __global__ void flash_attn_vec_ext_f32(
75
58
 
76
59
  const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
77
60
 
61
+ const int sequence = blockIdx.z / ne02;
62
+ const int head = blockIdx.z - sequence*ne02;
78
63
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
79
- Q += nb02* blockIdx.z + nb01*ic0;
80
- K += nb12*(blockIdx.z / gqa_ratio);
81
- V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
82
- const half * maskh = (const half *) mask + ne11*ic0;
64
+ Q += nb03*sequence + nb02* head + nb01*ic0;
65
+ K += nb13*sequence + nb12*(head / gqa_ratio);
66
+ V += nb23*sequence + nb22*(head / gqa_ratio);
83
67
 
84
- const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
68
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
69
+
70
+ const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
85
71
 
86
72
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
87
73
  constexpr int nwarps = D / WARP_SIZE;
@@ -191,13 +177,16 @@ static __global__ void flash_attn_vec_ext_f32(
191
177
 
192
178
  float VKQ[ncols] = {0.0f};
193
179
 
180
+ K += blockIdx.y*D * nb11;
181
+ V += blockIdx.y*D * nb21;
182
+ maskh += blockIdx.y*D;
194
183
  for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
195
184
  // Calculate KQ tile and keep track of new maximum KQ values:
196
185
 
197
186
  if (mask) {
198
187
  #pragma unroll
199
188
  for (int j = 0; j < ncols; ++j) {
200
- maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
189
+ maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
201
190
  }
202
191
 
203
192
  __syncthreads();
@@ -239,7 +228,7 @@ static __global__ void flash_attn_vec_ext_f32(
239
228
 
240
229
  #pragma unroll
241
230
  for (int j = 0; j < ncols; ++j) {
242
- float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
231
+ float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
243
232
  sum = warp_reduce_sum(sum);
244
233
 
245
234
  if (use_logit_softcap) {
@@ -290,13 +279,17 @@ static __global__ void flash_attn_vec_ext_f32(
290
279
  break;
291
280
  }
292
281
 
293
- const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
282
+ const float V_ki = dequantize_1_v(V + k*nb21, tid);
294
283
  #pragma unroll
295
284
  for (int j = 0; j < ncols; ++j) {
296
285
  VKQ[j] += V_ki*KQ[j*D + k];
297
286
  }
298
287
  }
299
288
 
289
+ K += gridDim.y*D * nb11;
290
+ V += gridDim.y*D * nb21;
291
+ maskh += gridDim.y*D;
292
+
300
293
  __syncthreads();
301
294
  }
302
295
 
@@ -323,24 +316,24 @@ static __global__ void flash_attn_vec_ext_f32(
323
316
  if (gridDim.y == 1) {
324
317
  dst_val /= kqsum[j_VKQ];
325
318
  }
326
- const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
327
- dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
319
+ dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
328
320
  }
329
321
 
330
322
  if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
331
- dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
323
+ dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
332
324
  }
333
325
  #else
334
326
  GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
335
327
  GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
336
328
  GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
337
- GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
338
- GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
339
- GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
340
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
341
- GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
342
- GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
343
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
329
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
330
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
331
+ GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
332
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
333
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
334
+ GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
335
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
336
+ GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
344
337
  NO_DEVICE_CODE;
345
338
  #endif // FLASH_ATTN_AVAILABLE
346
339
  }
@@ -37,29 +37,13 @@ static __global__ void flash_attn_ext_f16(
37
37
  const float m1,
38
38
  const uint32_t n_head_log2,
39
39
  const float logit_softcap,
40
- const int ne00,
41
- const int ne01,
42
- const int ne02,
43
- const int ne03,
44
- const int ne10,
45
- const int ne11,
46
- const int ne12,
47
- const int ne13,
48
- const int ne31,
49
- const int nb31,
50
- const int nb01,
51
- const int nb02,
52
- const int nb03,
53
- const int nb11,
54
- const int nb12,
55
- const int nb13,
56
- const int nb21,
57
- const int nb22,
58
- const int nb23,
59
- const int ne0,
60
- const int ne1,
61
- const int ne2,
62
- const int ne3) {
40
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
41
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
42
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
43
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
44
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
45
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
46
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
63
47
  #if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
64
48
  // Skip unused kernel variants for faster compilation:
65
49
  if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -93,17 +77,19 @@ static __global__ void flash_attn_ext_f16(
93
77
  constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
94
78
  constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
95
79
 
80
+ const int sequence = blockIdx.z / ne02;
81
+ const int head = blockIdx.z - sequence*ne02;
96
82
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
97
- const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
98
- const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
99
- const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
100
- const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
101
- const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
83
+ const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
84
+ const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
85
+ const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
86
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
87
+ const half2 * mask2 = (const half2 *) maskh;
102
88
 
103
89
  const int stride_Q = nb01 / sizeof(float);
104
90
  const int stride_KV = nb11 / sizeof(half);
105
91
 
106
- const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
92
+ const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
107
93
  const half slopeh = __float2half(slopef);
108
94
  const half2 slope2 = make_half2(slopef, slopef);
109
95
 
@@ -191,7 +177,7 @@ static __global__ void flash_attn_ext_f16(
191
177
  #pragma unroll
192
178
  for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
193
179
  frag_a_K K_a;
194
- wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
180
+ wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
195
181
  #pragma unroll
196
182
  for (int j = 0; j < ncols/frag_n; ++j) {
197
183
  wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -338,7 +324,7 @@ static __global__ void flash_attn_ext_f16(
338
324
  const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
339
325
 
340
326
  frag_a_V v_a;
341
- wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
327
+ wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
342
328
  #pragma unroll
343
329
  for (int j = 0; j < ncols/frag_n; ++j) {
344
330
  wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -398,7 +384,6 @@ static __global__ void flash_attn_ext_f16(
398
384
  if (ic0 + j_VKQ >= ne01) {
399
385
  return;
400
386
  }
401
- const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
402
387
 
403
388
  float KQ_rowsum_j;
404
389
  if (std::is_same<KQ_acc_t, float>::value) {
@@ -407,6 +392,8 @@ static __global__ void flash_attn_ext_f16(
407
392
  KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
408
393
  }
409
394
 
395
+ const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
396
+
410
397
  #pragma unroll
411
398
  for (int i0 = 0; i0 < D; i0 += warp_size) {
412
399
  const int i = i0 + threadIdx.x;
@@ -417,7 +404,7 @@ static __global__ void flash_attn_ext_f16(
417
404
  if (gridDim.y == 1) {
418
405
  dst_val /= KQ_rowsum_j;
419
406
  }
420
- dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
407
+ dst[j_dst_unrolled*D + i] = dst_val;
421
408
  }
422
409
 
423
410
  if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -431,7 +418,7 @@ static __global__ void flash_attn_ext_f16(
431
418
  dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
432
419
  }
433
420
  dst_meta_val.y = KQ_rowsum_j;
434
- dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
421
+ dst_meta[j_dst_unrolled] = dst_meta_val;
435
422
  }
436
423
  #else
437
424
  GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -440,10 +427,10 @@ static __global__ void flash_attn_ext_f16(
440
427
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
441
428
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
442
429
  GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
443
- GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
430
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
431
+ GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
444
432
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
445
433
  GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
446
- GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
447
434
  NO_DEVICE_CODE;
448
435
  #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
449
436
  }
@@ -280,22 +280,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
280
280
  const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
281
281
  const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
282
282
 
283
- if (GGML_CUDA_CC_IS_AMD(cc)) {
284
283
  #if defined(GGML_HIP_ROCWMMA_FATTN)
285
- if (fp16_mma_available(cc)) {
286
- ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
287
- return;
288
- }
289
- #endif // defined(GGML_HIP_ROCWMMA_FATTN)
290
-
291
- // On AMD the tile kernels perform poorly, use the vec kernel instead:
292
- if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
293
- ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
294
- } else {
295
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
296
- }
284
+ if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
285
+ ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
297
286
  return;
298
287
  }
288
+ #endif // defined(GGML_HIP_ROCWMMA_FATTN)
299
289
 
300
290
  if (!fast_fp16_available(cc)) {
301
291
  if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
168
168
  get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
169
169
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
170
170
  break;
171
+ case GGML_TYPE_I32:
172
+ get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
173
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
174
+ break;
171
175
  case GGML_TYPE_BF16:
172
176
  get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
173
177
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
@@ -210,6 +214,10 @@ void get_rows_cuda(
210
214
  ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
211
215
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
212
216
  break;
217
+ case GGML_TYPE_I32:
218
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
219
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
220
+ break;
213
221
  case GGML_TYPE_F16:
214
222
  ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
215
223
  ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);