@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
123
123
  ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
124
124
 
125
125
  if (nbytes_shared <= smpbo) {
126
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
127
- static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
128
- if (!shared_memory_limit_raised[id]) {
129
- CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
130
- shared_memory_limit_raised[id] = true;
131
- }
132
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
126
+ CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
133
127
  cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
134
128
  } else {
135
129
  cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
175
169
  const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
176
170
 
177
171
  if (nbytes_shared <= smpbo) {
178
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
179
- static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
180
- if (!shared_memory_limit_raised[id]) {
181
- CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
182
- shared_memory_limit_raised[id] = true;
183
- }
184
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
172
+ CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
185
173
  cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
186
174
  } else {
187
175
  cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
@@ -23,29 +23,13 @@ typedef void (* fattn_kernel_t)(
23
23
  const float m1,
24
24
  const uint32_t n_head_log2,
25
25
  const float logit_softcap,
26
- const int ne00,
27
- const int ne01,
28
- const int ne02,
29
- const int ne03,
30
- const int ne10,
31
- const int ne11,
32
- const int ne12,
33
- const int ne13,
34
- const int ne31,
35
- const int nb31,
36
- const int nb01,
37
- const int nb02,
38
- const int nb03,
39
- const int nb11,
40
- const int nb12,
41
- const int nb13,
42
- const int nb21,
43
- const int nb22,
44
- const int nb23,
45
- const int ne0,
46
- const int ne1,
47
- const int ne2,
48
- const int ne3);
26
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
27
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
28
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
29
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
30
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
31
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
32
+ const int32_t nb31, const int32_t nb32, const int64_t nb33);
49
33
 
50
34
  typedef half (*vec_dot_KQ_f16_t)(
51
35
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
@@ -519,7 +503,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
519
503
  template<int D, int ncols1, int ncols2> // D == head size
520
504
  __launch_bounds__(D, 1)
521
505
  static __global__ void flash_attn_stream_k_fixup(
522
- float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
506
+ float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
523
507
  constexpr int ncols = ncols1*ncols2;
524
508
 
525
509
  const int bidx0 = blockIdx.x;
@@ -533,8 +517,8 @@ static __global__ void flash_attn_stream_k_fixup(
533
517
  const int iter_k = ne11 / FATTN_KQ_STRIDE;
534
518
  const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
535
519
 
536
- const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
537
- const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
520
+ const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
521
+ const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
538
522
 
539
523
  const bool did_not_have_any_data = kbc0 == kbc0_stop;
540
524
  const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -543,14 +527,15 @@ static __global__ void flash_attn_stream_k_fixup(
543
527
  return;
544
528
  }
545
529
 
546
- const int channel = kbc0 / (iter_k*iter_j);
547
- const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
530
+ const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
531
+ const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
532
+ const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
548
533
 
549
534
  if (jt*ncols1 + j >= ne01) {
550
535
  return;
551
536
  }
552
537
 
553
- dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
538
+ dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
554
539
 
555
540
  // Load the partial result that needs a fixup:
556
541
  float dst_val = 0.0f;
@@ -569,7 +554,7 @@ static __global__ void flash_attn_stream_k_fixup(
569
554
  int bidx = bidx0 - 1;
570
555
  int kbc_stop = kbc0;
571
556
  while(true) {
572
- const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
557
+ const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
573
558
  if (kbc == kbc_stop) { // Did not have any data.
574
559
  bidx--;
575
560
  kbc_stop = kbc;
@@ -615,16 +600,31 @@ static __global__ void flash_attn_combine_results(
615
600
  const float2 * __restrict__ VKQ_meta,
616
601
  float * __restrict__ dst,
617
602
  const int parallel_blocks) {
618
- VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
619
- VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
620
- dst += D * gridDim.z*blockIdx.x;
603
+ // Dimension 0: threadIdx.x
604
+ // Dimension 1: blockIdx.x
605
+ // Dimension 2: blockIdx.y
606
+ // Dimension 3: blockIdx.z
607
+ // Memory layout is permuted with [0, 2, 1, 3]
608
+
609
+ const int ne01 = gridDim.x;
610
+ const int ne02 = gridDim.y;
611
+
612
+ const int col = blockIdx.x;
613
+ const int head = blockIdx.y;
614
+ const int sequence = blockIdx.z;
615
+
616
+ const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
617
+
618
+ VKQ_parts += j_dst_unrolled * parallel_blocks*D;
619
+ VKQ_meta += j_dst_unrolled * parallel_blocks;
620
+ dst += j_dst_unrolled * D;
621
621
 
622
622
  const int tid = threadIdx.x;
623
623
  __builtin_assume(tid < D);
624
624
 
625
625
  extern __shared__ float2 meta[];
626
626
  for (int i = tid; i < 2*parallel_blocks; i += D) {
627
- ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
627
+ ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
628
628
  }
629
629
 
630
630
  __syncthreads();
@@ -642,11 +642,11 @@ static __global__ void flash_attn_combine_results(
642
642
  const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
643
643
  *((uint32_t *) &KQ_max_scale) &= ftz_mask;
644
644
 
645
- VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
645
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
646
646
  VKQ_denominator += KQ_max_scale * meta[l].y;
647
647
  }
648
648
 
649
- dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
649
+ dst[tid] = VKQ_numerator / VKQ_denominator;
650
650
  }
651
651
 
652
652
  [[noreturn]]
@@ -703,8 +703,6 @@ void launch_fattn(
703
703
 
704
704
  GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
705
705
 
706
- GGML_ASSERT(Q->ne[3] == 1);
707
-
708
706
  ggml_cuda_pool & pool = ctx.pool();
709
707
  cudaStream_t main_stream = ctx.stream();
710
708
  const int id = ggml_cuda_get_device();
@@ -727,33 +725,58 @@ void launch_fattn(
727
725
  size_t nb23 = V ? V->nb[3] : nb13;
728
726
 
729
727
  if (need_f16_K && K->type != GGML_TYPE_F16) {
730
- GGML_ASSERT(ggml_is_contiguously_allocated(K));
731
- K_f16.alloc(ggml_nelements(K));
732
- to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
733
- to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
734
- K_data = (char *) K_f16.ptr;
735
-
736
728
  const size_t bs = ggml_blck_size(K->type);
737
729
  const size_t ts = ggml_type_size(K->type);
738
730
 
739
- nb11 = nb11*bs*sizeof(half)/ts;
740
- nb12 = nb12*bs*sizeof(half)/ts;
741
- nb13 = nb13*bs*sizeof(half)/ts;
731
+ K_f16.alloc(ggml_nelements(K));
732
+ if (ggml_is_contiguously_allocated(K)) {
733
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
734
+ to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
735
+
736
+ nb11 = nb11*bs*sizeof(half)/ts;
737
+ nb12 = nb12*bs*sizeof(half)/ts;
738
+ nb13 = nb13*bs*sizeof(half)/ts;
739
+ } else {
740
+ GGML_ASSERT(K->nb[0] == ts);
741
+ to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
742
+ const int64_t s01 = nb11 / ts;
743
+ const int64_t s02 = nb12 / ts;
744
+ const int64_t s03 = nb13 / ts;
745
+ to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
746
+
747
+ nb11 = K->ne[0] * sizeof(half);
748
+ nb12 = K->ne[1] * nb11;
749
+ nb13 = K->ne[2] * nb12;
750
+ }
751
+ K_data = (char *) K_f16.ptr;
742
752
  }
743
753
 
744
754
  if (V && need_f16_V && V->type != GGML_TYPE_F16) {
745
- GGML_ASSERT(ggml_is_contiguously_allocated(V));
746
- V_f16.alloc(ggml_nelements(V));
747
- to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
748
- to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
749
- V_data = (char *) V_f16.ptr;
750
-
751
755
  const size_t bs = ggml_blck_size(V->type);
752
756
  const size_t ts = ggml_type_size(V->type);
753
757
 
754
- nb21 = nb21*bs*sizeof(half)/ts;
755
- nb22 = nb22*bs*sizeof(half)/ts;
756
- nb23 = nb23*bs*sizeof(half)/ts;
758
+ V_f16.alloc(ggml_nelements(V));
759
+ if (ggml_is_contiguously_allocated(V)) {
760
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
761
+ to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
762
+ V_data = (char *) V_f16.ptr;
763
+
764
+ nb21 = nb21*bs*sizeof(half)/ts;
765
+ nb22 = nb22*bs*sizeof(half)/ts;
766
+ nb23 = nb23*bs*sizeof(half)/ts;
767
+ } else {
768
+ GGML_ASSERT(V->nb[0] == ts);
769
+ to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
770
+ const int64_t s01 = nb21 / ts;
771
+ const int64_t s02 = nb22 / ts;
772
+ const int64_t s03 = nb23 / ts;
773
+ to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
774
+
775
+ nb21 = V->ne[0] * sizeof(half);
776
+ nb22 = V->ne[1] * nb21;
777
+ nb23 = V->ne[2] * nb22;
778
+ }
779
+ V_data = (char *) V_f16.ptr;
757
780
  }
758
781
 
759
782
  int parallel_blocks = 1;
@@ -849,13 +872,11 @@ void launch_fattn(
849
872
  mask ? ((const char *) mask->data) : nullptr,
850
873
  !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
851
874
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
852
- Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
853
- K->ne[0], K->ne[1], K->ne[2], K->ne[3],
854
- mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
855
- Q->nb[1], Q->nb[2], Q->nb[3],
856
- nb11, nb12, nb13,
875
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
876
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
857
877
  nb21, nb22, nb23,
858
- KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
878
+ mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
879
+ mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
859
880
  );
860
881
  CUDA_CHECK(cudaGetLastError());
861
882
 
@@ -866,11 +887,11 @@ void launch_fattn(
866
887
 
867
888
  flash_attn_stream_k_fixup<DV, ncols1, ncols2>
868
889
  <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
869
- ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
890
+ ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
870
891
  }
871
892
  } else if (parallel_blocks > 1) {
872
893
  const dim3 block_dim_combine(DV, 1, 1);
873
- const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
894
+ const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
874
895
  const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
875
896
 
876
897
  flash_attn_combine_results<DV>
@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
408
408
  const int stride_K,
409
409
  const int stride_V,
410
410
  const int stride_mask,
411
- const int jt,
412
411
  half2 * const __restrict__ tile_Q,
413
412
  half2 * const __restrict__ tile_K,
414
413
  half2 * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
455
454
  cp_async_wait_all();
456
455
  __syncthreads();
457
456
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
458
- (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
457
+ (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
459
458
  } else {
460
459
  constexpr bool use_cp_async = nstages == 1;
461
460
  if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
471
470
  if (nstages <= 1) {
472
471
  constexpr bool use_cp_async = nstages == 1;
473
472
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
474
- (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
473
+ (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
475
474
  if (use_cp_async) {
476
475
  cp_async_wait_all();
477
476
  }
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
715
714
  (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
716
715
  }
717
716
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
718
- (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
717
+ (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
719
718
  }
720
719
  }
721
720
 
@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
732
731
  if (nstages <= 1 && i0_start < reusable_cutoff) {
733
732
  constexpr bool use_cp_async = nstages == 1;
734
733
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
735
- (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
734
+ (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
736
735
  if (use_cp_async) {
737
736
  cp_async_wait_all();
738
737
  }
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
771
770
  GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
772
771
  GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
773
772
  GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
774
- GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
775
- GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
773
+ GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
776
774
  GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
777
775
  GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
778
776
  GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
920
918
  (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
921
919
  }
922
920
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
923
- (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
921
+ (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
924
922
  }
925
923
 
926
924
  // Iterate over ne11 == previous tokens:
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
928
926
  constexpr bool last_iter = false;
929
927
  flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
930
928
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
931
- ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
929
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
932
930
  }
933
931
  { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
934
932
  constexpr bool last_iter = true;
935
933
  flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
936
934
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
937
- ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
935
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
938
936
  }
939
937
 
940
938
  // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,29 +1212,13 @@ static __global__ void flash_attn_ext_f16(
1214
1212
  const float m1,
1215
1213
  const uint32_t n_head_log2,
1216
1214
  const float logit_softcap,
1217
- const int ne00,
1218
- const int ne01,
1219
- const int ne02,
1220
- const int ne03,
1221
- const int ne10,
1222
- const int ne11,
1223
- const int ne12,
1224
- const int ne13,
1225
- const int ne31,
1226
- const int nb31,
1227
- const int nb01,
1228
- const int nb02,
1229
- const int nb03,
1230
- const int nb11,
1231
- const int nb12,
1232
- const int nb13,
1233
- const int nb21,
1234
- const int nb22,
1235
- const int nb23,
1236
- const int ne0,
1237
- const int ne1,
1238
- const int ne2,
1239
- const int ne3) {
1215
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
1216
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
1217
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
1218
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
1219
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
1220
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
1221
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
1240
1222
  #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
1241
1223
 
1242
1224
  // Skip unused kernel variants for faster compilation:
@@ -1272,8 +1254,8 @@ static __global__ void flash_attn_ext_f16(
1272
1254
  constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
1273
1255
 
1274
1256
  // kbc == k block continuous, current index in continuous ijk space.
1275
- int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
1276
- const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
1257
+ int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1258
+ const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1277
1259
 
1278
1260
  // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
1279
1261
  // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1283,17 +1265,19 @@ static __global__ void flash_attn_ext_f16(
1283
1265
  int kb0_start = kbc % iter_k;
1284
1266
  int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
1285
1267
  while (kbc < kbc_stop && kb0_stop == iter_k) {
1286
- const int channel = kbc / (iter_k*iter_j);
1287
- const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
1268
+ const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1269
+ const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
1270
+ const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
1288
1271
 
1289
- const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1290
- const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1291
- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1292
- float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1272
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
1273
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
1274
+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1275
+ (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1276
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
1293
1277
 
1294
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1278
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
1295
1279
 
1296
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1280
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
1297
1281
 
1298
1282
  const int kb0_start_kernel = kb0_start * kb_niter;
1299
1283
  const int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1322,17 +1306,19 @@ static __global__ void flash_attn_ext_f16(
1322
1306
  return;
1323
1307
  }
1324
1308
 
1325
- const int channel = kbc / (iter_k*iter_j);
1326
- const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
1309
+ const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1310
+ const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
1311
+ const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
1327
1312
 
1328
- const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1329
- const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1330
- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1331
- float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1313
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
1314
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
1315
+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1316
+ (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1317
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
1332
1318
 
1333
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1319
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
1334
1320
 
1335
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1321
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
1336
1322
 
1337
1323
  const int kb0_start_kernel = kb0_start * kb_niter;
1338
1324
  const int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1344,15 +1330,16 @@ static __global__ void flash_attn_ext_f16(
1344
1330
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1345
1331
  #else
1346
1332
  GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
1347
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
1348
- GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
1349
- GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
1350
- GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
1351
- GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
1352
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
1353
- GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
1354
- GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
1355
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
1333
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
1334
+ GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
1335
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
1336
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
1337
+ GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
1338
+ GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
1339
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
1340
+ GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
1341
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
1342
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
1356
1343
  NO_DEVICE_CODE;
1357
1344
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
1358
1345
  }
@@ -6,7 +6,7 @@
6
6
 
7
7
  template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
- __launch_bounds__(nwarps*WARP_SIZE, 1)
9
+ __launch_bounds__(nwarps*WARP_SIZE, 2)
10
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
11
  static __global__ void flash_attn_tile_ext_f16(
12
12
  const char * __restrict__ Q,
@@ -21,29 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
21
21
  const float m1,
22
22
  const uint32_t n_head_log2,
23
23
  const float logit_softcap,
24
- const int ne00,
25
- const int ne01,
26
- const int ne02,
27
- const int ne03,
28
- const int ne10,
29
- const int ne11,
30
- const int ne12,
31
- const int ne13,
32
- const int ne31,
33
- const int nb31,
34
- const int nb01,
35
- const int nb02,
36
- const int nb03,
37
- const int nb11,
38
- const int nb12,
39
- const int nb13,
40
- const int nb21,
41
- const int nb22,
42
- const int nb23,
43
- const int ne0,
44
- const int ne1,
45
- const int ne2,
46
- const int ne3) {
24
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
25
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
26
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
27
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
28
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
29
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
30
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
47
31
  #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
48
32
 
49
33
  // Skip unused kernel variants for faster compilation:
@@ -60,15 +44,17 @@ static __global__ void flash_attn_tile_ext_f16(
60
44
 
61
45
  const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
62
46
 
47
+ const int sequence = blockIdx.z / ne02;
48
+ const int head = blockIdx.z - sequence*ne02;
63
49
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
64
- const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
65
- const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
66
- const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
67
- const half * maskh = (const half *) mask + ne11*ic0;
50
+ const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
51
+ const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
52
+ const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
53
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
68
54
 
69
55
  const int stride_KV2 = nb11 / sizeof(half2);
70
56
 
71
- const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
57
+ const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
72
58
  const half slopeh = __float2half(slopef);
73
59
 
74
60
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -121,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
121
107
  for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
122
108
  const int k_KQ = k_KQ_0 + threadIdx.x;
123
109
 
124
- KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
110
+ KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
125
111
  }
126
112
  }
127
113
 
@@ -215,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
215
201
  for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
216
202
  const int i = i0 + threadIdx.x;
217
203
 
218
- KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
204
+ KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
219
205
  }
220
206
  }
221
207
 
@@ -253,6 +239,8 @@ static __global__ void flash_attn_tile_ext_f16(
253
239
  __syncthreads();
254
240
  }
255
241
 
242
+ float2 * dst2 = (float2 *) dst;
243
+
256
244
  #pragma unroll
257
245
  for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
258
246
  const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -264,21 +252,21 @@ static __global__ void flash_attn_tile_ext_f16(
264
252
  half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
265
253
  kqsum_j = warp_reduce_sum((float)kqsum_j);
266
254
 
255
+ const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
256
+
267
257
  #pragma unroll
268
- for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
269
- const int i0 = i00 + 2*threadIdx.x;
258
+ for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
259
+ const int i0 = i00 + threadIdx.x;
270
260
 
271
- half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
261
+ half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
272
262
  if (gridDim.y == 1) {
273
263
  dst_val /= __half2half2(kqsum_j);
274
264
  }
275
- const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
276
- dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
277
- dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
265
+ dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
278
266
  }
279
267
 
280
268
  if (gridDim.y != 1 && threadIdx.x == 0) {
281
- dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
269
+ dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
282
270
  }
283
271
  }
284
272
  #else
@@ -288,12 +276,11 @@ static __global__ void flash_attn_tile_ext_f16(
288
276
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
289
277
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
290
278
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
291
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
292
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
279
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
280
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
293
281
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
294
282
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
295
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
296
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
283
+ GGML_UNUSED(nb23);
297
284
  NO_DEVICE_CODE;
298
285
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
299
286
  }