llama_cpp 0.15.4 → 0.16.0

Files changed (147)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +15 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +13 -1
  7. data/vendor/tmp/llama.cpp/Makefile +62 -35
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +8 -6
  131. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -6
  132. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  133. data/vendor/tmp/llama.cpp/ggml-metal.m +34 -24
  134. data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
  135. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +2 -2
  136. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +7 -67
  137. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  138. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +456 -329
  139. data/vendor/tmp/llama.cpp/ggml.c +178 -330
  140. data/vendor/tmp/llama.cpp/ggml.h +9 -28
  141. data/vendor/tmp/llama.cpp/llama.cpp +242 -426
  142. data/vendor/tmp/llama.cpp/llama.h +17 -43
  143. metadata +121 -6
  144. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  145. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  146. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  147. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
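
For anyone upgrading, a minimal sketch of verifying the new release from Ruby, assuming the LLaMACpp::VERSION and LLaMACpp::LLAMA_CPP_VERSION constants defined in data/lib/llama_cpp/version.rb (file 5 above); check that file if the constant names differ in your install:

    # Gemfile
    gem 'llama_cpp', '~> 0.16.0'

    # after `bundle update llama_cpp`
    require 'llama_cpp'
    puts LLaMACpp::VERSION            # gem version, expected "0.16.0"
    puts LLaMACpp::LLAMA_CPP_VERSION  # tag of the bundled llama.cpp sources
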
data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu (new file; entry 21 in the list above)
@@ -0,0 +1,319 @@
+ #include "common.cuh"
+ #include "fattn-common.cuh"
+ #include "fattn-tile-f16.cuh"
+
+ #define FATTN_KQ_STRIDE_TILE_F16 64
+
+ template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+ __launch_bounds__(nwarps*WARP_SIZE, 1)
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+ static __global__ void flash_attn_tile_ext_f16(
+         const char * __restrict__ Q,
+         const char * __restrict__ K,
+         const char * __restrict__ V,
+         const char * __restrict__ mask,
+         float * __restrict__ dst,
+         float2 * __restrict__ dst_meta,
+         const float scale,
+         const float max_bias,
+         const float m0,
+         const float m1,
+         const uint32_t n_head_log2,
+         const int ne00,
+         const int ne01,
+         const int ne02,
+         const int ne03,
+         const int ne10,
+         const int ne11,
+         const int ne12,
+         const int ne13,
+         const int ne31,
+         const int nb31,
+         const int nb01,
+         const int nb02,
+         const int nb03,
+         const int nb11,
+         const int nb12,
+         const int nb13,
+         const int nb21,
+         const int nb22,
+         const int nb23,
+         const int ne0,
+         const int ne1,
+         const int ne2,
+         const int ne3) {
+ #if FP16_AVAILABLE
+     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
+     const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
+     const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+     const half * maskh = (const half *) mask + ne11*ic0;
+
+     const int stride_KV2 = nb11 / sizeof(half2);
+
+     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+     const half slopeh = __float2half(slopef);
+
+     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+
+     __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
+     half2 * KQ2 = (half2 *) KQ;
+
+     __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
+
+     half kqmax[ncols/nwarps];
+     #pragma unroll
+     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+         kqmax[j0/nwarps] = -HALF_MAX_HALF;
+     }
+     half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
+
+     half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+
+     // Convert Q to half2 and store in registers:
+     __shared__ half2 Q_h2[ncols][D/2];
+     #pragma unroll
+     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+         const int j = j0 + threadIdx.y;
+
+         #pragma unroll
+         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+             const int i = i0 + threadIdx.x;
+
+             const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
+             Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+         }
+     }
+
+     __syncthreads();
+
+     const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
+     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+         // Calculate KQ tile and keep track of new maximum KQ values:
+
+         half kqmax_new[ncols/nwarps];
+         #pragma unroll
+         for (int j = 0; j < ncols/nwarps; ++j) {
+             kqmax_new[j] = kqmax[j];
+         }
+
+         #pragma unroll
+         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
+             const int i_KQ = i_KQ_0 + threadIdx.y;
+
+             #pragma unroll
+             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+                 const int k_KQ = k_KQ_0 + threadIdx.x;
+
+                 KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+             }
+         }
+
+         __syncthreads();
+
+         half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
+
+         #pragma unroll
+         for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
+             half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
+             half2 Q_k[ncols/nwarps];
+
+             #pragma unroll
+             for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+                 const int i_KQ = i_KQ_0 + threadIdx.x;
+
+                 K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
+             }
+             #pragma unroll
+             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                 const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                 Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
+             }
+
+             #pragma unroll
+             for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+                 #pragma unroll
+                 for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                     sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
+                 }
+             }
+         }
+
+         #pragma unroll
+         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+             const int i_KQ = i_KQ_0 + threadIdx.x;
+
+             #pragma unroll
+             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                 const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                 half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                 sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                 kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
+
+                 KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
+             }
+         }
+
+         __syncthreads();
+
+         #pragma unroll
+         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+             const int j = j0 + threadIdx.y;
+
+             kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
+             const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
+             kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
+
+             #pragma unroll
+             for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
+                 const half2 val = h2exp(diff);
+                 kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
+                 KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
+             }
+
+             #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+             }
+         }
+
+         __syncthreads();
+
+         #pragma unroll
+         for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
+             const int k = k0 + threadIdx.y;
+
+             #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+             }
+         }
+
+         __syncthreads();
+
+         #pragma unroll
+         for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
+             half2 V_k[(D/2)/WARP_SIZE][2];
+             half2 KQ_k[ncols/nwarps];
+
+             #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
+                 V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
+             }
+             #pragma unroll
+             for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                 const int j = j0 + threadIdx.y;
+
+                 KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
+             }
+
+             #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 #pragma unroll
+                 for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                     VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
+                     VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
+                 }
+             }
+         }
+
+         __syncthreads();
+     }
+
+     #pragma unroll
+     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
+         const int j_VKQ = j_VKQ_0 + threadIdx.y;
+
+         if (ic0 + j_VKQ >= ne01) {
+             return;
+         }
+
+         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
+         kqsum_j = warp_reduce_sum(kqsum_j);
+
+         #pragma unroll
+         for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
+             const int i0 = i00 + 2*threadIdx.x;
+
+             half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+             if (parallel_blocks == 1) {
+                 dst_val /= __half2half2(kqsum_j);
+             }
+             const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+             dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
+             dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+         }
+
+         if (parallel_blocks != 1 && threadIdx.x == 0) {
+             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+         }
+     }
+ #else
+     NO_DEVICE_CODE;
+ #endif // FP16_AVAILABLE
+ }
+
+ template <int cols_per_block, int parallel_blocks>
+ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     const ggml_tensor * Q = dst->src[0];
+     switch (Q->ne[0]) {
+         case 64: {
+             constexpr int D = 64;
+             constexpr int nwarps = 8;
+             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+         } break;
+         case 128: {
+             constexpr int D = 128;
+             constexpr int nwarps = 8;
+             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+         } break;
+         default: {
+             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+         } break;
+     }
+ }
+
+ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     const ggml_tensor * KQV = dst;
+     const ggml_tensor * Q = dst->src[0];
+
+     const int32_t precision = KQV->op_params[2];
+     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+
+     if (Q->ne[1] <= 16) {
+         constexpr int cols_per_block = 16;
+         constexpr int parallel_blocks = 4;
+         launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+         return;
+     }
+
+     if (Q->ne[1] <= 32) {
+         constexpr int cols_per_block = 32;
+         constexpr int parallel_blocks = 4;
+         launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+         return;
+     }
+
+     constexpr int cols_per_block = 32;
+     constexpr int parallel_blocks = 1;
+     launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ }
data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu (new file; entry 22 in the list above)
@@ -0,0 +1,312 @@
+ #include "common.cuh"
+ #include "fattn-common.cuh"
+ #include "fattn-tile-f32.cuh"
+
+ #define FATTN_KQ_STRIDE_TILE_F32 32
+
+ template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+ __launch_bounds__(nwarps*WARP_SIZE, 1)
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+ static __global__ void flash_attn_tile_ext_f32(
+         const char * __restrict__ Q,
+         const char * __restrict__ K,
+         const char * __restrict__ V,
+         const char * __restrict__ mask,
+         float * __restrict__ dst,
+         float2 * __restrict__ dst_meta,
+         const float scale,
+         const float max_bias,
+         const float m0,
+         const float m1,
+         const uint32_t n_head_log2,
+         const int ne00,
+         const int ne01,
+         const int ne02,
+         const int ne03,
+         const int ne10,
+         const int ne11,
+         const int ne12,
+         const int ne13,
+         const int ne31,
+         const int nb31,
+         const int nb01,
+         const int nb02,
+         const int nb03,
+         const int nb11,
+         const int nb12,
+         const int nb13,
+         const int nb21,
+         const int nb22,
+         const int nb23,
+         const int ne0,
+         const int ne1,
+         const int ne2,
+         const int ne3) {
+     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
+     const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
+     const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+     const half * maskh = (const half *) mask + ne11*ic0;
+
+     const int stride_KV2 = nb11 / sizeof(half2);
+
+     const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+
+     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+
+     __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
+
+     __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
+     float2 * KV_tmp2 = (float2 *) KV_tmp;
+
+     float kqmax[ncols/nwarps];
+     #pragma unroll
+     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+         kqmax[j0/nwarps] = -FLT_MAX/2.0f;
+     }
+     float kqsum[ncols/nwarps] = {0.0f};
+
+     float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+
+     // Convert Q to half2 and store in registers:
+     __shared__ float Q_f[ncols][D];
+     #pragma unroll
+     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+         const int j = j0 + threadIdx.y;
+
+         #pragma unroll
+         for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
+             float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
+             Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
+             Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
+         }
+     }
+
+     __syncthreads();
+
+     const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
+     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
+         // Calculate KQ tile and keep track of new maximum KQ values:
+
+         float kqmax_new[ncols/nwarps];
+         #pragma unroll
+         for (int j = 0; j < ncols/nwarps; ++j) {
+             kqmax_new[j] = kqmax[j];
+         }
+
+         #pragma unroll
+         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) {
+             const int i_KQ = i_KQ_0 + threadIdx.y;
+
+             #pragma unroll
+             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
+                 const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+                 KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
+                 KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
+             }
+         }
+
+         __syncthreads();
+
+         float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
+
+         #pragma unroll
+         for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
+             float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
+             float Q_k[ncols/nwarps];
+
+             #pragma unroll
+             for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+                 const int i_KQ = i_KQ_0 + threadIdx.x;
+
+                 K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
+             }
+             #pragma unroll
+             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                 const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                 Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ];
+             }
+
+             #pragma unroll
+             for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+                 #pragma unroll
+                 for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                     sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps];
+                 }
+             }
+         }
+
+         #pragma unroll
+         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+             const int i_KQ = i_KQ_0 + threadIdx.x;
+
+             #pragma unroll
+             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                 const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                 sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+                 kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+
+                 KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps];
+             }
+         }
+
+         __syncthreads();
+
+         #pragma unroll
+         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+             const int j = j0 + threadIdx.y;
+
+             kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
+             const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
+             kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
+
+             float kqsum_add = 0.0f;
+             #pragma unroll
+             for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps];
+                 const float val = expf(diff);
+                 kqsum_add += val;
+                 KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val;
+             }
+             kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
+
+             #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                 VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+             }
+         }
+
+         __syncthreads();
+
+         #pragma unroll
+         for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) {
+             const int k = k0 + threadIdx.y;
+
+             #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+                 KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+             }
+         }
+
+         __syncthreads();
+
+         #pragma unroll
+         for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
+             float2 V_k[(D/2)/WARP_SIZE];
+             float KQ_k[ncols/nwarps];
+
+             #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
+             }
+             #pragma unroll
+             for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                 const int j = j0 + threadIdx.y;
+
+                 KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k];
+             }
+
+             #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 #pragma unroll
+                 for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                     VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
+                     VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps];
+                 }
+             }
+         }
+
+         __syncthreads();
+     }
+
+     #pragma unroll
+     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
+         const int j_VKQ = j_VKQ_0 + threadIdx.y;
+
+         if (ic0 + j_VKQ >= ne01) {
+             return;
+         }
+
+         float kqsum_j = kqsum[j_VKQ_0/nwarps];
+         kqsum_j = warp_reduce_sum(kqsum_j);
+
+         #pragma unroll
+         for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
+             const int i0 = i00 + 2*threadIdx.x;
+
+             float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+             if (parallel_blocks == 1) {
+                 dst_val.x /= kqsum_j;
+                 dst_val.y /= kqsum_j;
+             }
+             const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+             dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
+             dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
+         }
+
+         if (parallel_blocks != 1 && threadIdx.x == 0) {
+             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+         }
+     }
+ }
+
+ template <int cols_per_block, int parallel_blocks>
+ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     const ggml_tensor * Q = dst->src[0];
+     switch (Q->ne[0]) {
+         case 64: {
+             constexpr int D = 64;
+             constexpr int nwarps = 8;
+             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+         } break;
+         case 128: {
+             constexpr int D = 128;
+             constexpr int nwarps = 8;
+             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+         } break;
+         default: {
+             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+         } break;
+     }
+ }
+
+ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     const ggml_tensor * Q = dst->src[0];
+
+     if (Q->ne[1] <= 16) {
+         constexpr int cols_per_block = 16;
+         constexpr int parallel_blocks = 4;
+         launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+         return;
+     }
+
+     if (Q->ne[1] <= 32) {
+         constexpr int cols_per_block = 32;
+         constexpr int parallel_blocks = 4;
+         launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+         return;
+     }
+
+     constexpr int cols_per_block = 32;
+     constexpr int parallel_blocks = 1;
+     launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ }