@novastera-oss/llamarn 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakePresets.json +11 -0
  22. package/cpp/llama.cpp/CODEOWNERS +1 -0
  23. package/cpp/llama.cpp/README.md +4 -3
  24. package/cpp/llama.cpp/common/arg.cpp +45 -1
  25. package/cpp/llama.cpp/common/common.cpp +22 -6
  26. package/cpp/llama.cpp/common/common.h +18 -4
  27. package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
  28. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
  30. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  31. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  32. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  34. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
  35. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
  77. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
  78. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
  79. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
  109. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  112. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  115. package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
  116. package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
  117. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  118. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
  120. package/cpp/llama.cpp/include/llama.h +15 -7
  121. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  122. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  123. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  124. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  125. package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
  126. package/cpp/llama.cpp/src/llama-arch.h +5 -0
  127. package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
  128. package/cpp/llama.cpp/src/llama-batch.h +24 -18
  129. package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
  130. package/cpp/llama.cpp/src/llama-chat.h +2 -0
  131. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  132. package/cpp/llama.cpp/src/llama-context.h +26 -16
  133. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  134. package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
  135. package/cpp/llama.cpp/src/llama-graph.h +147 -72
  136. package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
  137. package/cpp/llama.cpp/src/llama-hparams.h +10 -2
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  139. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  140. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  141. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  142. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
  144. package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
  145. package/cpp/llama.cpp/src/llama-model.h +3 -4
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
  148. package/cpp/llama.cpp/src/llama-vocab.h +2 -0
  149. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  150. package/cpp/llama.cpp/src/unicode.h +2 -0
  151. package/ios/include/common.h +18 -4
  152. package/ios/include/llama.h +15 -7
  153. package/ios/libs/llama.xcframework/Info.plist +15 -15
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  155. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  158. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  165. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  172. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  174. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
  175. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  176. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  177. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  178. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  179. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  180. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
  183. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
  184. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  186. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
  187. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
  188. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  189. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  190. package/package.json +4 -4
@@ -23,31 +23,13 @@ typedef void (* fattn_kernel_t)(
23
23
  const float m1,
24
24
  const uint32_t n_head_log2,
25
25
  const float logit_softcap,
26
- const int ne00,
27
- const int ne01,
28
- const int ne02,
29
- const int ne03,
30
- const int ne10,
31
- const int ne11,
32
- const int ne12,
33
- const int ne13,
34
- const int ne31,
35
- const int ne32,
36
- const int nb31,
37
- const int nb32,
38
- const int nb01,
39
- const int nb02,
40
- const int nb03,
41
- const int nb11,
42
- const int nb12,
43
- const int nb13,
44
- const int nb21,
45
- const int nb22,
46
- const int nb23,
47
- const int ne0,
48
- const int ne1,
49
- const int ne2,
50
- const int ne3);
26
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
27
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
28
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
29
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
30
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
31
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
32
+ const int32_t nb31, const int32_t nb32, const int64_t nb33);
51
33
 
52
34
  typedef half (*vec_dot_KQ_f16_t)(
53
35
  const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
@@ -521,7 +503,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
521
503
  template<int D, int ncols1, int ncols2> // D == head size
522
504
  __launch_bounds__(D, 1)
523
505
  static __global__ void flash_attn_stream_k_fixup(
524
- float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
506
+ float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
525
507
  constexpr int ncols = ncols1*ncols2;
526
508
 
527
509
  const int bidx0 = blockIdx.x;
@@ -535,8 +517,8 @@ static __global__ void flash_attn_stream_k_fixup(
535
517
  const int iter_k = ne11 / FATTN_KQ_STRIDE;
536
518
  const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
537
519
 
538
- const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
539
- const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
520
+ const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
521
+ const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
540
522
 
541
523
  const bool did_not_have_any_data = kbc0 == kbc0_stop;
542
524
  const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +527,15 @@ static __global__ void flash_attn_stream_k_fixup(
545
527
  return;
546
528
  }
547
529
 
548
- const int channel = kbc0 / (iter_k*iter_j);
549
- const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
530
+ const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
531
+ const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
532
+ const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
550
533
 
551
534
  if (jt*ncols1 + j >= ne01) {
552
535
  return;
553
536
  }
554
537
 
555
- dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
538
+ dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
556
539
 
557
540
  // Load the partial result that needs a fixup:
558
541
  float dst_val = 0.0f;
@@ -571,7 +554,7 @@ static __global__ void flash_attn_stream_k_fixup(
571
554
  int bidx = bidx0 - 1;
572
555
  int kbc_stop = kbc0;
573
556
  while(true) {
574
- const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
557
+ const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
575
558
  if (kbc == kbc_stop) { // Did not have any data.
576
559
  bidx--;
577
560
  kbc_stop = kbc;
@@ -617,16 +600,31 @@ static __global__ void flash_attn_combine_results(
617
600
  const float2 * __restrict__ VKQ_meta,
618
601
  float * __restrict__ dst,
619
602
  const int parallel_blocks) {
620
- VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
621
- VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
622
- dst += D * gridDim.z*blockIdx.x;
603
+ // Dimension 0: threadIdx.x
604
+ // Dimension 1: blockIdx.x
605
+ // Dimension 2: blockIdx.y
606
+ // Dimension 3: blockIdx.z
607
+ // Memory layout is permuted with [0, 2, 1, 3]
608
+
609
+ const int ne01 = gridDim.x;
610
+ const int ne02 = gridDim.y;
611
+
612
+ const int col = blockIdx.x;
613
+ const int head = blockIdx.y;
614
+ const int sequence = blockIdx.z;
615
+
616
+ const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
617
+
618
+ VKQ_parts += j_dst_unrolled * parallel_blocks*D;
619
+ VKQ_meta += j_dst_unrolled * parallel_blocks;
620
+ dst += j_dst_unrolled * D;
623
621
 
624
622
  const int tid = threadIdx.x;
625
623
  __builtin_assume(tid < D);
626
624
 
627
625
  extern __shared__ float2 meta[];
628
626
  for (int i = tid; i < 2*parallel_blocks; i += D) {
629
- ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
627
+ ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
630
628
  }
631
629
 
632
630
  __syncthreads();
@@ -644,11 +642,11 @@ static __global__ void flash_attn_combine_results(
644
642
  const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
645
643
  *((uint32_t *) &KQ_max_scale) &= ftz_mask;
646
644
 
647
- VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
645
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
648
646
  VKQ_denominator += KQ_max_scale * meta[l].y;
649
647
  }
650
648
 
651
- dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
649
+ dst[tid] = VKQ_numerator / VKQ_denominator;
652
650
  }
653
651
 
654
652
  [[noreturn]]
@@ -705,8 +703,6 @@ void launch_fattn(
705
703
 
706
704
  GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
707
705
 
708
- GGML_ASSERT(Q->ne[3] == 1);
709
-
710
706
  ggml_cuda_pool & pool = ctx.pool();
711
707
  cudaStream_t main_stream = ctx.stream();
712
708
  const int id = ggml_cuda_get_device();
@@ -729,33 +725,58 @@ void launch_fattn(
729
725
  size_t nb23 = V ? V->nb[3] : nb13;
730
726
 
731
727
  if (need_f16_K && K->type != GGML_TYPE_F16) {
732
- GGML_ASSERT(ggml_is_contiguously_allocated(K));
733
- K_f16.alloc(ggml_nelements(K));
734
- to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
735
- to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
736
- K_data = (char *) K_f16.ptr;
737
-
738
728
  const size_t bs = ggml_blck_size(K->type);
739
729
  const size_t ts = ggml_type_size(K->type);
740
730
 
741
- nb11 = nb11*bs*sizeof(half)/ts;
742
- nb12 = nb12*bs*sizeof(half)/ts;
743
- nb13 = nb13*bs*sizeof(half)/ts;
731
+ K_f16.alloc(ggml_nelements(K));
732
+ if (ggml_is_contiguously_allocated(K)) {
733
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
734
+ to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
735
+
736
+ nb11 = nb11*bs*sizeof(half)/ts;
737
+ nb12 = nb12*bs*sizeof(half)/ts;
738
+ nb13 = nb13*bs*sizeof(half)/ts;
739
+ } else {
740
+ GGML_ASSERT(K->nb[0] == ts);
741
+ to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
742
+ const int64_t s01 = nb11 / ts;
743
+ const int64_t s02 = nb12 / ts;
744
+ const int64_t s03 = nb13 / ts;
745
+ to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
746
+
747
+ nb11 = K->ne[0] * sizeof(half);
748
+ nb12 = K->ne[1] * nb11;
749
+ nb13 = K->ne[2] * nb12;
750
+ }
751
+ K_data = (char *) K_f16.ptr;
744
752
  }
745
753
 
746
754
  if (V && need_f16_V && V->type != GGML_TYPE_F16) {
747
- GGML_ASSERT(ggml_is_contiguously_allocated(V));
748
- V_f16.alloc(ggml_nelements(V));
749
- to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
750
- to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
751
- V_data = (char *) V_f16.ptr;
752
-
753
755
  const size_t bs = ggml_blck_size(V->type);
754
756
  const size_t ts = ggml_type_size(V->type);
755
757
 
756
- nb21 = nb21*bs*sizeof(half)/ts;
757
- nb22 = nb22*bs*sizeof(half)/ts;
758
- nb23 = nb23*bs*sizeof(half)/ts;
758
+ V_f16.alloc(ggml_nelements(V));
759
+ if (ggml_is_contiguously_allocated(V)) {
760
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
761
+ to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
762
+ V_data = (char *) V_f16.ptr;
763
+
764
+ nb21 = nb21*bs*sizeof(half)/ts;
765
+ nb22 = nb22*bs*sizeof(half)/ts;
766
+ nb23 = nb23*bs*sizeof(half)/ts;
767
+ } else {
768
+ GGML_ASSERT(V->nb[0] == ts);
769
+ to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
770
+ const int64_t s01 = nb21 / ts;
771
+ const int64_t s02 = nb22 / ts;
772
+ const int64_t s03 = nb23 / ts;
773
+ to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
774
+
775
+ nb21 = V->ne[0] * sizeof(half);
776
+ nb22 = V->ne[1] * nb21;
777
+ nb23 = V->ne[2] * nb22;
778
+ }
779
+ V_data = (char *) V_f16.ptr;
759
780
  }
760
781
 
761
782
  int parallel_blocks = 1;
@@ -851,14 +872,11 @@ void launch_fattn(
851
872
  mask ? ((const char *) mask->data) : nullptr,
852
873
  !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
853
874
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
854
- Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
855
- K->ne[0], K->ne[1], K->ne[2], K->ne[3],
856
- mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
857
- mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
858
- Q->nb[1], Q->nb[2], Q->nb[3],
859
- nb11, nb12, nb13,
875
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
876
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
860
877
  nb21, nb22, nb23,
861
- KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
878
+ mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
879
+ mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
862
880
  );
863
881
  CUDA_CHECK(cudaGetLastError());
864
882
 
@@ -869,11 +887,11 @@ void launch_fattn(
869
887
 
870
888
  flash_attn_stream_k_fixup<DV, ncols1, ncols2>
871
889
  <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
872
- ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
890
+ ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
873
891
  }
874
892
  } else if (parallel_blocks > 1) {
875
893
  const dim3 block_dim_combine(DV, 1, 1);
876
- const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
894
+ const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
877
895
  const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
878
896
 
879
897
  flash_attn_combine_results<DV>
@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
408
408
  const int stride_K,
409
409
  const int stride_V,
410
410
  const int stride_mask,
411
- const int jt,
412
411
  half2 * const __restrict__ tile_Q,
413
412
  half2 * const __restrict__ tile_K,
414
413
  half2 * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
455
454
  cp_async_wait_all();
456
455
  __syncthreads();
457
456
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
458
- (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
457
+ (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
459
458
  } else {
460
459
  constexpr bool use_cp_async = nstages == 1;
461
460
  if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
471
470
  if (nstages <= 1) {
472
471
  constexpr bool use_cp_async = nstages == 1;
473
472
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
474
- (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
473
+ (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
475
474
  if (use_cp_async) {
476
475
  cp_async_wait_all();
477
476
  }
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
715
714
  (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
716
715
  }
717
716
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
718
- (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
717
+ (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
719
718
  }
720
719
  }
721
720
 
@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
732
731
  if (nstages <= 1 && i0_start < reusable_cutoff) {
733
732
  constexpr bool use_cp_async = nstages == 1;
734
733
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
735
- (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
734
+ (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
736
735
  if (use_cp_async) {
737
736
  cp_async_wait_all();
738
737
  }
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
771
770
  GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
772
771
  GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
773
772
  GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
774
- GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
775
- GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
773
+ GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
776
774
  GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
777
775
  GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
778
776
  GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
920
918
  (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
921
919
  }
922
920
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
923
- (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
921
+ (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
924
922
  }
925
923
 
926
924
  // Iterate over ne11 == previous tokens:
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
928
926
  constexpr bool last_iter = false;
929
927
  flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
930
928
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
931
- ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
929
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
932
930
  }
933
931
  { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
934
932
  constexpr bool last_iter = true;
935
933
  flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
936
934
  (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
937
- ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
935
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
938
936
  }
939
937
 
940
938
  // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,31 +1212,13 @@ static __global__ void flash_attn_ext_f16(
1214
1212
  const float m1,
1215
1213
  const uint32_t n_head_log2,
1216
1214
  const float logit_softcap,
1217
- const int ne00,
1218
- const int ne01,
1219
- const int ne02,
1220
- const int ne03,
1221
- const int ne10,
1222
- const int ne11,
1223
- const int ne12,
1224
- const int ne13,
1225
- const int ne31,
1226
- const int ne32,
1227
- const int nb31,
1228
- const int nb32,
1229
- const int nb01,
1230
- const int nb02,
1231
- const int nb03,
1232
- const int nb11,
1233
- const int nb12,
1234
- const int nb13,
1235
- const int nb21,
1236
- const int nb22,
1237
- const int nb23,
1238
- const int ne0,
1239
- const int ne1,
1240
- const int ne2,
1241
- const int ne3) {
1215
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
1216
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
1217
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
1218
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
1219
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
1220
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
1221
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
1242
1222
  #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
1243
1223
 
1244
1224
  // Skip unused kernel variants for faster compilation:
@@ -1274,8 +1254,8 @@ static __global__ void flash_attn_ext_f16(
1274
1254
  constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
1275
1255
 
1276
1256
  // kbc == k block continuous, current index in continuous ijk space.
1277
- int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
1278
- const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
1257
+ int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1258
+ const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1279
1259
 
1280
1260
  // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
1281
1261
  // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1265,19 @@ static __global__ void flash_attn_ext_f16(
1285
1265
  int kb0_start = kbc % iter_k;
1286
1266
  int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
1287
1267
  while (kbc < kbc_stop && kb0_stop == iter_k) {
1288
- const int channel = kbc / (iter_k*iter_j);
1289
- const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
1268
+ const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1269
+ const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
1270
+ const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
1290
1271
 
1291
- const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1292
- const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1272
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
1273
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
1293
1274
  const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1294
- (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1295
- float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1275
+ (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1276
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
1296
1277
 
1297
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1278
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
1298
1279
 
1299
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1280
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
1300
1281
 
1301
1282
  const int kb0_start_kernel = kb0_start * kb_niter;
1302
1283
  const int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1325,18 +1306,19 @@ static __global__ void flash_attn_ext_f16(
1325
1306
  return;
1326
1307
  }
1327
1308
 
1328
- const int channel = kbc / (iter_k*iter_j);
1329
- const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
1309
+ const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1310
+ const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
1311
+ const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
1330
1312
 
1331
- const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1332
- const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1313
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
1314
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
1333
1315
  const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1334
- (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1335
- float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1316
+ (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1317
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
1336
1318
 
1337
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
1319
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
1338
1320
 
1339
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
1321
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
1340
1322
 
1341
1323
  const int kb0_start_kernel = kb0_start * kb_niter;
1342
1324
  const int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1348,15 +1330,16 @@ static __global__ void flash_attn_ext_f16(
1348
1330
  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1349
1331
  #else
1350
1332
  GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
1351
- GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
1352
- GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
1353
- GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
1354
- GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
1355
- GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
1356
- GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
1357
- GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
1358
- GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
1359
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
1333
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
1334
+ GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
1335
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
1336
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
1337
+ GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
1338
+ GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
1339
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
1340
+ GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
1341
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
1342
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
1360
1343
  NO_DEVICE_CODE;
1361
1344
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
1362
1345
  }
@@ -21,31 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
21
21
  const float m1,
22
22
  const uint32_t n_head_log2,
23
23
  const float logit_softcap,
24
- const int ne00,
25
- const int ne01,
26
- const int ne02,
27
- const int ne03,
28
- const int ne10,
29
- const int ne11,
30
- const int ne12,
31
- const int ne13,
32
- const int ne31,
33
- const int ne32,
34
- const int nb31,
35
- const int nb32,
36
- const int nb01,
37
- const int nb02,
38
- const int nb03,
39
- const int nb11,
40
- const int nb12,
41
- const int nb13,
42
- const int nb21,
43
- const int nb22,
44
- const int nb23,
45
- const int ne0,
46
- const int ne1,
47
- const int ne2,
48
- const int ne3) {
24
+ const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
25
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
26
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
27
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
28
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
29
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
30
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
49
31
  #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
50
32
 
51
33
  // Skip unused kernel variants for faster compilation:
@@ -62,15 +44,17 @@ static __global__ void flash_attn_tile_ext_f16(
62
44
 
63
45
  const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
64
46
 
47
+ const int sequence = blockIdx.z / ne02;
48
+ const int head = blockIdx.z - sequence*ne02;
65
49
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
66
- const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
67
- const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
68
- const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
69
- const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
50
+ const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
51
+ const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
52
+ const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
53
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
70
54
 
71
55
  const int stride_KV2 = nb11 / sizeof(half2);
72
56
 
73
- const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
57
+ const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
74
58
  const half slopeh = __float2half(slopef);
75
59
 
76
60
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -123,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
123
107
  for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
124
108
  const int k_KQ = k_KQ_0 + threadIdx.x;
125
109
 
126
- KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
110
+ KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
127
111
  }
128
112
  }
129
113
 
@@ -217,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
217
201
  for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
218
202
  const int i = i0 + threadIdx.x;
219
203
 
220
- KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
204
+ KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
221
205
  }
222
206
  }
223
207
 
@@ -255,6 +239,8 @@ static __global__ void flash_attn_tile_ext_f16(
255
239
  __syncthreads();
256
240
  }
257
241
 
242
+ float2 * dst2 = (float2 *) dst;
243
+
258
244
  #pragma unroll
259
245
  for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
260
246
  const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +252,21 @@ static __global__ void flash_attn_tile_ext_f16(
266
252
  half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
267
253
  kqsum_j = warp_reduce_sum((float)kqsum_j);
268
254
 
255
+ const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
256
+
269
257
  #pragma unroll
270
- for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
271
- const int i0 = i00 + 2*threadIdx.x;
258
+ for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
259
+ const int i0 = i00 + threadIdx.x;
272
260
 
273
- half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
261
+ half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
274
262
  if (gridDim.y == 1) {
275
263
  dst_val /= __half2half2(kqsum_j);
276
264
  }
277
- const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
278
- dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
279
- dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
265
+ dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
280
266
  }
281
267
 
282
268
  if (gridDim.y != 1 && threadIdx.x == 0) {
283
- dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
269
+ dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
284
270
  }
285
271
  }
286
272
  #else
@@ -290,12 +276,11 @@ static __global__ void flash_attn_tile_ext_f16(
290
276
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
291
277
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
292
278
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
293
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
294
- GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
279
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
280
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
295
281
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
296
282
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
297
- GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
298
- GGML_UNUSED(ne2); GGML_UNUSED(ne3);
283
+ GGML_UNUSED(nb23);
299
284
  NO_DEVICE_CODE;
300
285
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
301
286
  }