llama_cpp 0.16.2 → 0.17.0

Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -12
  4. data/ext/llama_cpp/extconf.rb +2 -43
  5. data/ext/llama_cpp/llama_cpp.cpp +8 -0
  6. data/lib/llama_cpp/version.rb +3 -3
  7. data/sig/llama_cpp.rbs +3 -0
  8. metadata +2 -171
  9. data/vendor/include/.gitkeep +0 -0
  10. data/vendor/lib/.gitkeep +0 -0
  11. data/vendor/tmp/llama.cpp/LICENSE +0 -21
  12. data/vendor/tmp/llama.cpp/Makefile +0 -1124
  13. data/vendor/tmp/llama.cpp/ggml-alloc.c +0 -1041
  14. data/vendor/tmp/llama.cpp/ggml-alloc.h +0 -76
  15. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +0 -153
  16. data/vendor/tmp/llama.cpp/ggml-backend.c +0 -2225
  17. data/vendor/tmp/llama.cpp/ggml-backend.h +0 -236
  18. data/vendor/tmp/llama.cpp/ggml-blas.cpp +0 -363
  19. data/vendor/tmp/llama.cpp/ggml-blas.h +0 -23
  20. data/vendor/tmp/llama.cpp/ggml-common.h +0 -1805
  21. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +0 -47
  22. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +0 -34
  23. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +0 -104
  24. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +0 -280
  25. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +0 -34
  26. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +0 -196
  27. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +0 -686
  28. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +0 -490
  29. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +0 -40
  30. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +0 -674
  31. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +0 -319
  32. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +0 -312
  33. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +0 -345
  34. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +0 -178
  35. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +0 -104
  36. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +0 -88
  37. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +0 -419
  38. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +0 -221
  39. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +0 -49
  40. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +0 -94
  41. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +0 -112
  42. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +0 -271
  43. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +0 -31
  44. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +0 -206
  45. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +0 -40
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  127. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  128. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  129. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  130. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  131. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  132. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +0 -10
  133. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +0 -9
  134. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +0 -10
  135. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +0 -10
  136. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +0 -8
  137. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +0 -5
  138. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +0 -5
  139. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +0 -5
  140. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +0 -5
  141. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +0 -5
  142. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +0 -5
  143. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +0 -5
  144. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +0 -5
  145. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +0 -5
  146. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +0 -5
  147. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +0 -47
  148. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +0 -314
  149. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +0 -51
  150. data/vendor/tmp/llama.cpp/ggml-cuda.cu +0 -3069
  151. data/vendor/tmp/llama.cpp/ggml-cuda.h +0 -44
  152. data/vendor/tmp/llama.cpp/ggml-impl.h +0 -651
  153. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -2038
  154. data/vendor/tmp/llama.cpp/ggml-kompute.h +0 -46
  155. data/vendor/tmp/llama.cpp/ggml-metal.h +0 -66
  156. data/vendor/tmp/llama.cpp/ggml-metal.m +0 -3273
  157. data/vendor/tmp/llama.cpp/ggml-metal.metal +0 -6540
  158. data/vendor/tmp/llama.cpp/ggml-quants.c +0 -14994
  159. data/vendor/tmp/llama.cpp/ggml-quants.h +0 -133
  160. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +0 -1178
  161. data/vendor/tmp/llama.cpp/ggml-rpc.h +0 -24
  162. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +0 -6351
  163. data/vendor/tmp/llama.cpp/ggml-sycl.h +0 -40
  164. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +0 -144508
  165. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +0 -7183
  166. data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -29
  167. data/vendor/tmp/llama.cpp/ggml.c +0 -22506
  168. data/vendor/tmp/llama.cpp/ggml.h +0 -2458
  169. data/vendor/tmp/llama.cpp/llama.cpp +0 -18985
  170. data/vendor/tmp/llama.cpp/llama.h +0 -1147
  171. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +0 -38
  172. data/vendor/tmp/llama.cpp/sgemm.cpp +0 -1032
  173. data/vendor/tmp/llama.cpp/sgemm.h +0 -14
  174. data/vendor/tmp/llama.cpp/unicode-data.cpp +0 -7033
  175. data/vendor/tmp/llama.cpp/unicode-data.h +0 -20
  176. data/vendor/tmp/llama.cpp/unicode.cpp +0 -810
  177. data/vendor/tmp/llama.cpp/unicode.h +0 -63
data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu
@@ -1,345 +0,0 @@
- #include "common.cuh"
- #include "fattn-common.cuh"
- #include "fattn-tile-f16.cuh"
- #include "fattn-tile-f32.cuh"
- #include "fattn-vec-f16.cuh"
- #include "fattn-vec-f32.cuh"
- #include "fattn-wmma-f16.cuh"
- #include "fattn.cuh"
-
- #include <cstdint>
-
- static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-     const ggml_tensor * KQV = dst;
-     const ggml_tensor * Q = dst->src[0];
-
-     const int32_t precision = KQV->op_params[2];
-
-     if (precision != GGML_PREC_DEFAULT) {
-         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
-             constexpr int cols_per_block = 16;
-             switch (Q->ne[0]) {
-                 case 64:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
-                     break;
-                 case 80:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
-                     break;
-                 case 96:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
-                     break;
-                 case 112:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
-                     break;
-                 case 128:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                     break;
-                 case 256:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
-                     break;
-                 default:
-                     GGML_ASSERT(false);
-                     break;
-             }
-         } else {
-             constexpr int cols_per_block = 32;
-             switch (Q->ne[0]) {
-                 case 64:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
-                     break;
-                 case 80:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
-                     break;
-                 case 96:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
-                     break;
-                 case 112:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
-                     break;
-                 case 128:
-                     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                     break;
-                 // case 256:
-                 // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                 // break;
-                 default:
-                     GGML_ASSERT(false);
-                     break;
-             }
-         }
-         return;
-     }
-
-     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
-         constexpr int cols_per_block = 8;
-         switch (Q->ne[0]) {
-             case 64:
-                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-                 break;
-             case 96:
-                 ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-                 break;
-             case 128:
-                 ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-                 break;
-             case 256:
-                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-                 break;
-             default:
-                 GGML_ASSERT(false);
-                 break;
-         }
-         return;
-     }
-
-     if (Q->ne[1] <= 32) {
-         constexpr int cols_per_block = 16;
-         switch (Q->ne[0]) {
-             case 64:
-                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-                 break;
-             case 80:
-                 ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
-                 break;
-             case 96:
-                 ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-                 break;
-             case 112:
-                 ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
-                 break;
-             case 128:
-                 ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-                 break;
-             case 256:
-                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-                 break;
-             default:
-                 GGML_ASSERT(false);
-                 break;
-         }
-         return;
-     }
-
-     constexpr int cols_per_block = 32;
-     switch (Q->ne[0]) {
-         case 64:
-             ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-             break;
-         case 80:
-             ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
-             break;
-         case 96:
-             ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-             break;
-         case 112:
-             ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
-             break;
-         case 128:
-             ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-             break;
-         case 256:
-             ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-             break;
-         default:
-             GGML_ASSERT(false);
-             break;
-     }
- }
- #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
-     if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
-         ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
-         return; \
-     } \
-
- static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-     ggml_tensor * Q = dst->src[1];
-     ggml_tensor * K = dst->src[1];
-     ggml_tensor * V = dst->src[2];
-
- #ifdef GGML_CUDA_FA_ALL_QUANTS
-     FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
-     FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
-     FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
-     FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
-     FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-     FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
-
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
-
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
-
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
-
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
-
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-
-     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
- #else
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
-
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
-
-     FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
- #endif // GGML_CUDA_FA_ALL_QUANTS
-
-     on_no_fattn_vec_case(Q->ne[0]);
- }
-
- #define FATTN_VEC_F32_CASE(D, type_K, type_V) \
-     if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
-         ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
-         return; \
-     } \
-
- static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-     ggml_tensor * Q = dst->src[1];
-     ggml_tensor * K = dst->src[1];
-     ggml_tensor * V = dst->src[2];
-
- #ifdef GGML_CUDA_FA_ALL_QUANTS
-     FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
-     FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
-     FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
-     FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
-     FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-     FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
-
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
-
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
-
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
-
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-
-     FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
- #else
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
-
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
-
-     FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-     FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-     FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
- #endif // GGML_CUDA_FA_ALL_QUANTS
-
-     on_no_fattn_vec_case(Q->ne[0]);
- }
-
- void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-     const ggml_tensor * KQV = dst;
-     const ggml_tensor * Q = dst->src[0];
-
-     ggml_cuda_set_device(ctx.device);
-     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-     const int32_t precision = KQV->op_params[2];
-
-     // On AMD the tile kernels perform poorly, use the vec kernel instead:
-     if (cc >= CC_OFFSET_AMD) {
-         if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-         } else {
-             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-         }
-         return;
-     }
-
-     if (!fast_fp16_available(cc)) {
-         if (Q->ne[1] <= 8) {
-             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-         } else {
-             ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
-         }
-         return;
-     }
-
-     if (!fp16_mma_available(cc)) {
-         if (Q->ne[1] <= 8) {
-             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-         } else {
-             ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
-         }
-         return;
-     }
-
-     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
-         if (precision == GGML_PREC_DEFAULT) {
-             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-             return;
-         } else if(Q->ne[0] <= 128) {
-             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-             return;
-         }
-     }
-
-     ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
- }
data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu
@@ -1,178 +0,0 @@
- #include "getrows.cuh"
- #include "dequantize.cuh"
-
- template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
- static __global__ void k_get_rows(
-         const void * src0, const int32_t * src1, dst_t * dst,
-         int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-         /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-         /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-         /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-         size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
-
-     const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-     const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
-     const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
-     const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
-
-     if (i00 >= ne00) {
-         return;
-     }
-
-     const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
-     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-     const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
-
-     const int ib = i00/qk; // block index
-     const int iqs = (i00%qk)/qr; // quant index
-     const int iybs = i00 - i00%qk; // dst block start index
-     const int y_offset = qr == 1 ? 1 : qk/2;
-
-     // dequantize
-     dfloat2 v;
-     dequantize_kernel(src0_row, ib, iqs, v);
-
-     dst_row[iybs + iqs + 0] = v.x;
-     dst_row[iybs + iqs + y_offset] = v.y;
- }
-
- template<typename src0_t, typename dst_t>
- static __global__ void k_get_rows_float(
-         const src0_t * src0, const int32_t * src1, dst_t * dst,
-         int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-         /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-         /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-         /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-         size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
-
-     const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
-     const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
-     const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
-     const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
-
-     if (i00 >= ne00) {
-         return;
-     }
-
-     const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
-     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-     const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
-
-     dst_row[i00] = src0_row[i00];
- }
-
- template<int qk, int qr, dequantize_kernel_t dq>
- static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-         const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
-
-     GGML_TENSOR_BINARY_OP_LOCALS
-
-     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-     const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
-
-     // strides in elements
-     //const size_t s0 = nb0 / ggml_element_size(dst);
-     const size_t s1 = nb1 / ggml_element_size(dst);
-     const size_t s2 = nb2 / ggml_element_size(dst);
-     const size_t s3 = nb3 / ggml_element_size(dst);
-
-     const size_t s10 = nb10 / ggml_element_size(src1);
-     const size_t s11 = nb11 / ggml_element_size(src1);
-     const size_t s12 = nb12 / ggml_element_size(src1);
-     //const size_t s13 = nb13 / ggml_element_size(src1);
-
-     GGML_ASSERT(ne00 % 2 == 0);
-
-     k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
-         src0_dd, src1_dd, dst_dd,
-         ne00, /*ne01, ne02, ne03,*/
-         /*ne10, ne11,*/ ne12, /*ne13,*/
-         /* s0,*/ s1, s2, s3,
-         /* nb00,*/ nb01, nb02, nb03,
-         s10, s11, s12/*, s13*/);
-
-     GGML_UNUSED(dst);
- }
-
- template<typename src0_t>
- static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-         const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
-
-     GGML_TENSOR_BINARY_OP_LOCALS
-
-     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-     const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
-     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
-
-     // strides in elements
-     //const size_t s0 = nb0 / ggml_element_size(dst);
-     const size_t s1 = nb1 / ggml_element_size(dst);
-     const size_t s2 = nb2 / ggml_element_size(dst);
-     const size_t s3 = nb3 / ggml_element_size(dst);
-
-     const size_t s10 = nb10 / ggml_element_size(src1);
-     const size_t s11 = nb11 / ggml_element_size(src1);
-     const size_t s12 = nb12 / ggml_element_size(src1);
-     //const size_t s13 = nb13 / ggml_element_size(src1);
-
-     k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
-         src0_dd, src1_dd, dst_dd,
-         ne00, /*ne01, ne02, ne03,*/
-         /*ne10, ne11,*/ ne12, /*ne13,*/
-         /* s0,*/ s1, s2, s3,
-         /* nb00,*/ nb01, nb02, nb03,
-         s10, s11, s12/*, s13*/);
-
-     GGML_UNUSED(dst);
- }
-
- void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-     const ggml_tensor * src0 = dst->src[0];
-     const ggml_tensor * src1 = dst->src[1];
-     const float * src0_d = (const float *)src0->data;
-     const float * src1_d = (const float *)src1->data;
-     float * dst_d = (float *)dst->data;
-     cudaStream_t stream = ctx.stream();
-
-
-     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-     GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-     GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-     GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-     const int32_t * src1_i32 = (const int32_t *) src1_d;
-
-     switch (src0->type) {
-         case GGML_TYPE_F16:
-             get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
-             break;
-         case GGML_TYPE_F32:
-             get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-             break;
-         case GGML_TYPE_Q4_0:
-             get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-             break;
-         case GGML_TYPE_Q4_1:
-             get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-             break;
-         case GGML_TYPE_Q5_0:
-             get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-             break;
-         case GGML_TYPE_Q5_1:
-             get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-             break;
-         case GGML_TYPE_Q8_0:
-             get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-             break;
-         default:
-             // TODO: k-quants
-             fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
-             GGML_ASSERT(false);
-             break;
-     }
- }
data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu
@@ -1,104 +0,0 @@
- #include "im2col.cuh"
-
- template <typename T>
- static __global__ void im2col_kernel(
-         const float * x, T * dst, int64_t batch_offset,
-         int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
-         int s0, int s1, int p0, int p1, int d0, int d1) {
-     const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
-     if (i >= pelements) {
-         return;
-     }
-
-     const int64_t ksize = OW * (KH > 1 ? KW : 1);
-     const int64_t kx = i / ksize;
-     const int64_t kd = kx * ksize;
-     const int64_t ky = (i - kd) / OW;
-     const int64_t ix = i % OW;
-
-     const int64_t oh = blockIdx.y;
-     const int64_t batch = blockIdx.z / IC;
-     const int64_t ic = blockIdx.z % IC;
-
-     const int64_t iiw = ix * s0 + kx * d0 - p0;
-     const int64_t iih = oh * s1 + ky * d1 - p1;
-
-     const int64_t offset_dst =
-         ((batch * OH + oh) * OW + ix) * CHW +
-         (ic * (KW * KH) + ky * KW + kx);
-
-     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-         dst[offset_dst] = 0.0f;
-     } else {
-         const int64_t offset_src = ic * offset_delta + batch * batch_offset;
-         dst[offset_dst] = x[offset_src + iih * IW + iiw];
-     }
- }
-
- template <typename T>
- static void im2col_cuda(const float * x, T* dst,
-         int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-         int64_t batch, int64_t batch_offset, int64_t offset_delta,
-         int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
-     const int parallel_elements = OW * KW * KH;
-     const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
-     dim3 block_nums(num_blocks, OH, batch * IC);
-     im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
- }
-
- static void im2col_cuda_f16(const float * x, half * dst,
-         int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-         int64_t batch, int64_t batch_offset, int64_t offset_delta,
-         int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
-
-     im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
- }
-
- static void im2col_cuda_f32(const float * x, float * dst,
-         int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-         int64_t batch, int64_t batch_offset, int64_t offset_delta,
-         int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
-
-     im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
- }
-
- void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-     const ggml_tensor * src0 = dst->src[0];
-     const ggml_tensor * src1 = dst->src[1];
-     const float * src1_d = (const float *)src1->data;
-     float * dst_d = (float *)dst->data;
-     cudaStream_t stream = ctx.stream();
-
-     GGML_ASSERT(src0->type == GGML_TYPE_F16);
-     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-     GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
-
-     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-     const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-     const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-     const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-     const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-     const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-     const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
-
-     const int64_t IC = src1->ne[is_2D ? 2 : 1];
-     const int64_t IH = is_2D ? src1->ne[1] : 1;
-     const int64_t IW = src1->ne[0];
-
-     const int64_t KH = is_2D ? src0->ne[1] : 1;
-     const int64_t KW = src0->ne[0];
-
-     const int64_t OH = is_2D ? dst->ne[2] : 1;
-     const int64_t OW = dst->ne[1];
-
-     const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-     const int64_t batch = src1->ne[3];
-     const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
-
-     if(dst->type == GGML_TYPE_F16) {
-         im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
-     } else {
-         im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
-     }
- }