llama_cpp 0.16.2 → 0.17.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -12
  4. data/ext/llama_cpp/extconf.rb +2 -43
  5. data/ext/llama_cpp/llama_cpp.cpp +8 -0
  6. data/lib/llama_cpp/version.rb +3 -3
  7. data/sig/llama_cpp.rbs +3 -0
  8. metadata +2 -171
  9. data/vendor/include/.gitkeep +0 -0
  10. data/vendor/lib/.gitkeep +0 -0
  11. data/vendor/tmp/llama.cpp/LICENSE +0 -21
  12. data/vendor/tmp/llama.cpp/Makefile +0 -1124
  13. data/vendor/tmp/llama.cpp/ggml-alloc.c +0 -1041
  14. data/vendor/tmp/llama.cpp/ggml-alloc.h +0 -76
  15. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +0 -153
  16. data/vendor/tmp/llama.cpp/ggml-backend.c +0 -2225
  17. data/vendor/tmp/llama.cpp/ggml-backend.h +0 -236
  18. data/vendor/tmp/llama.cpp/ggml-blas.cpp +0 -363
  19. data/vendor/tmp/llama.cpp/ggml-blas.h +0 -23
  20. data/vendor/tmp/llama.cpp/ggml-common.h +0 -1805
  21. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +0 -47
  22. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +0 -34
  23. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +0 -104
  24. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +0 -280
  25. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +0 -34
  26. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +0 -196
  27. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +0 -686
  28. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +0 -490
  29. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +0 -40
  30. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +0 -674
  31. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +0 -319
  32. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +0 -312
  33. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +0 -345
  34. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +0 -178
  35. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +0 -104
  36. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +0 -88
  37. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +0 -419
  38. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +0 -221
  39. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +0 -49
  40. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +0 -94
  41. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +0 -112
  42. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +0 -271
  43. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +0 -31
  44. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +0 -206
  45. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +0 -40
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  127. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  128. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  129. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  130. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  131. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  132. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +0 -10
  133. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +0 -9
  134. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +0 -10
  135. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +0 -10
  136. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +0 -8
  137. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +0 -5
  138. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +0 -5
  139. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +0 -5
  140. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +0 -5
  141. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +0 -5
  142. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +0 -5
  143. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +0 -5
  144. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +0 -5
  145. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +0 -5
  146. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +0 -5
  147. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +0 -47
  148. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +0 -314
  149. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +0 -51
  150. data/vendor/tmp/llama.cpp/ggml-cuda.cu +0 -3069
  151. data/vendor/tmp/llama.cpp/ggml-cuda.h +0 -44
  152. data/vendor/tmp/llama.cpp/ggml-impl.h +0 -651
  153. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -2038
  154. data/vendor/tmp/llama.cpp/ggml-kompute.h +0 -46
  155. data/vendor/tmp/llama.cpp/ggml-metal.h +0 -66
  156. data/vendor/tmp/llama.cpp/ggml-metal.m +0 -3273
  157. data/vendor/tmp/llama.cpp/ggml-metal.metal +0 -6540
  158. data/vendor/tmp/llama.cpp/ggml-quants.c +0 -14994
  159. data/vendor/tmp/llama.cpp/ggml-quants.h +0 -133
  160. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +0 -1178
  161. data/vendor/tmp/llama.cpp/ggml-rpc.h +0 -24
  162. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +0 -6351
  163. data/vendor/tmp/llama.cpp/ggml-sycl.h +0 -40
  164. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +0 -144508
  165. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +0 -7183
  166. data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -29
  167. data/vendor/tmp/llama.cpp/ggml.c +0 -22506
  168. data/vendor/tmp/llama.cpp/ggml.h +0 -2458
  169. data/vendor/tmp/llama.cpp/llama.cpp +0 -18985
  170. data/vendor/tmp/llama.cpp/llama.h +0 -1147
  171. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +0 -38
  172. data/vendor/tmp/llama.cpp/sgemm.cpp +0 -1032
  173. data/vendor/tmp/llama.cpp/sgemm.h +0 -14
  174. data/vendor/tmp/llama.cpp/unicode-data.cpp +0 -7033
  175. data/vendor/tmp/llama.cpp/unicode-data.h +0 -20
  176. data/vendor/tmp/llama.cpp/unicode.cpp +0 -810
  177. data/vendor/tmp/llama.cpp/unicode.h +0 -63
@@ -1,345 +0,0 @@
1
- #include "common.cuh"
2
- #include "fattn-common.cuh"
3
- #include "fattn-tile-f16.cuh"
4
- #include "fattn-tile-f32.cuh"
5
- #include "fattn-vec-f16.cuh"
6
- #include "fattn-vec-f32.cuh"
7
- #include "fattn-wmma-f16.cuh"
8
- #include "fattn.cuh"
9
-
10
- #include <cstdint>
11
-
12
- static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
13
- const ggml_tensor * KQV = dst;
14
- const ggml_tensor * Q = dst->src[0];
15
-
16
- const int32_t precision = KQV->op_params[2];
17
-
18
- if (precision != GGML_PREC_DEFAULT) {
19
- if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
20
- constexpr int cols_per_block = 16;
21
- switch (Q->ne[0]) {
22
- case 64:
23
- ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
24
- break;
25
- case 80:
26
- ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
27
- break;
28
- case 96:
29
- ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
30
- break;
31
- case 112:
32
- ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
33
- break;
34
- case 128:
35
- ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
36
- break;
37
- case 256:
38
- ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
39
- break;
40
- default:
41
- GGML_ASSERT(false);
42
- break;
43
- }
44
- } else {
45
- constexpr int cols_per_block = 32;
46
- switch (Q->ne[0]) {
47
- case 64:
48
- ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
49
- break;
50
- case 80:
51
- ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
52
- break;
53
- case 96:
54
- ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
55
- break;
56
- case 112:
57
- ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
58
- break;
59
- case 128:
60
- ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
61
- break;
62
- // case 256:
63
- // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
64
- // break;
65
- default:
66
- GGML_ASSERT(false);
67
- break;
68
- }
69
- }
70
- return;
71
- }
72
-
73
- if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
74
- constexpr int cols_per_block = 8;
75
- switch (Q->ne[0]) {
76
- case 64:
77
- ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
78
- break;
79
- case 96:
80
- ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
81
- break;
82
- case 128:
83
- ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
84
- break;
85
- case 256:
86
- ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
87
- break;
88
- default:
89
- GGML_ASSERT(false);
90
- break;
91
- }
92
- return;
93
- }
94
-
95
- if (Q->ne[1] <= 32) {
96
- constexpr int cols_per_block = 16;
97
- switch (Q->ne[0]) {
98
- case 64:
99
- ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
100
- break;
101
- case 80:
102
- ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
103
- break;
104
- case 96:
105
- ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
106
- break;
107
- case 112:
108
- ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
109
- break;
110
- case 128:
111
- ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
112
- break;
113
- case 256:
114
- ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
115
- break;
116
- default:
117
- GGML_ASSERT(false);
118
- break;
119
- }
120
- return;
121
- }
122
-
123
- constexpr int cols_per_block = 32;
124
- switch (Q->ne[0]) {
125
- case 64:
126
- ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
127
- break;
128
- case 80:
129
- ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
130
- break;
131
- case 96:
132
- ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
133
- break;
134
- case 112:
135
- ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
136
- break;
137
- case 128:
138
- ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
139
- break;
140
- case 256:
141
- ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
142
- break;
143
- default:
144
- GGML_ASSERT(false);
145
- break;
146
- }
147
- }
148
- #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
149
- if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
150
- ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
151
- return; \
152
- } \
153
-
154
- static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
155
- ggml_tensor * Q = dst->src[1];
156
- ggml_tensor * K = dst->src[1];
157
- ggml_tensor * V = dst->src[2];
158
-
159
- #ifdef GGML_CUDA_FA_ALL_QUANTS
160
- FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
161
- FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
162
- FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
163
- FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
164
- FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
165
- FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
166
-
167
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
168
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
169
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
170
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
171
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
172
- FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
173
-
174
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
175
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
176
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
177
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
178
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
179
- FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
180
-
181
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
182
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
183
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
184
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
185
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
186
- FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
187
-
188
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
189
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
190
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
191
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
192
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
193
- FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
194
-
195
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
196
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
197
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
198
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
199
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
200
- FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
201
-
202
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
203
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
204
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
205
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
206
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
207
- FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
208
-
209
- FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
210
- #else
211
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
212
-
213
- FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
214
-
215
- FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
216
- FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
217
- FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
218
- #endif // GGML_CUDA_FA_ALL_QUANTS
219
-
220
- on_no_fattn_vec_case(Q->ne[0]);
221
- }
222
-
223
- #define FATTN_VEC_F32_CASE(D, type_K, type_V) \
224
- if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
225
- ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
226
- return; \
227
- } \
228
-
229
- static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
230
- ggml_tensor * Q = dst->src[1];
231
- ggml_tensor * K = dst->src[1];
232
- ggml_tensor * V = dst->src[2];
233
-
234
- #ifdef GGML_CUDA_FA_ALL_QUANTS
235
- FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
236
- FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
237
- FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
238
- FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
239
- FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
240
- FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
241
-
242
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
243
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
244
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
245
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
246
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
247
- FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
248
-
249
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
250
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
251
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
252
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
253
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
254
- FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
255
-
256
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
257
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
258
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
259
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
260
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
261
- FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
262
-
263
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
264
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
265
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
266
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
267
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
268
- FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
269
-
270
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
271
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
272
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
273
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
274
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
275
- FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
276
-
277
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
278
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
279
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
280
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
281
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
282
- FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
283
-
284
- FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
285
- #else
286
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
287
-
288
- FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
289
-
290
- FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
291
- FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
292
- FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
293
- #endif // GGML_CUDA_FA_ALL_QUANTS
294
-
295
- on_no_fattn_vec_case(Q->ne[0]);
296
- }
297
-
298
- void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
299
- const ggml_tensor * KQV = dst;
300
- const ggml_tensor * Q = dst->src[0];
301
-
302
- ggml_cuda_set_device(ctx.device);
303
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
304
- const int32_t precision = KQV->op_params[2];
305
-
306
- // On AMD the tile kernels perform poorly, use the vec kernel instead:
307
- if (cc >= CC_OFFSET_AMD) {
308
- if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
309
- ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
310
- } else {
311
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
312
- }
313
- return;
314
- }
315
-
316
- if (!fast_fp16_available(cc)) {
317
- if (Q->ne[1] <= 8) {
318
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
319
- } else {
320
- ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
321
- }
322
- return;
323
- }
324
-
325
- if (!fp16_mma_available(cc)) {
326
- if (Q->ne[1] <= 8) {
327
- ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
328
- } else {
329
- ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
330
- }
331
- return;
332
- }
333
-
334
- if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
335
- if (precision == GGML_PREC_DEFAULT) {
336
- ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
337
- return;
338
- } else if(Q->ne[0] <= 128) {
339
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
340
- return;
341
- }
342
- }
343
-
344
- ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
345
- }
@@ -1,178 +0,0 @@
1
- #include "getrows.cuh"
2
- #include "dequantize.cuh"
3
-
4
- template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
5
- static __global__ void k_get_rows(
6
- const void * src0, const int32_t * src1, dst_t * dst,
7
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
8
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
9
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
10
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
11
- size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
12
-
13
- const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
14
- const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
15
- const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
16
- const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
17
-
18
- if (i00 >= ne00) {
19
- return;
20
- }
21
-
22
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
23
-
24
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
25
- const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
26
-
27
- const int ib = i00/qk; // block index
28
- const int iqs = (i00%qk)/qr; // quant index
29
- const int iybs = i00 - i00%qk; // dst block start index
30
- const int y_offset = qr == 1 ? 1 : qk/2;
31
-
32
- // dequantize
33
- dfloat2 v;
34
- dequantize_kernel(src0_row, ib, iqs, v);
35
-
36
- dst_row[iybs + iqs + 0] = v.x;
37
- dst_row[iybs + iqs + y_offset] = v.y;
38
- }
39
-
40
- template<typename src0_t, typename dst_t>
41
- static __global__ void k_get_rows_float(
42
- const src0_t * src0, const int32_t * src1, dst_t * dst,
43
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
44
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
45
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
46
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
47
- size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
48
-
49
- const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
50
- const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
51
- const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
52
- const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
53
-
54
- if (i00 >= ne00) {
55
- return;
56
- }
57
-
58
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
59
-
60
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
61
- const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
62
-
63
- dst_row[i00] = src0_row[i00];
64
- }
65
-
66
- template<int qk, int qr, dequantize_kernel_t dq>
67
- static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
68
- const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
69
-
70
- GGML_TENSOR_BINARY_OP_LOCALS
71
-
72
- const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
73
- const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
74
- const dim3 block_nums(block_num_x, ne10, ne11*ne12);
75
-
76
- // strides in elements
77
- //const size_t s0 = nb0 / ggml_element_size(dst);
78
- const size_t s1 = nb1 / ggml_element_size(dst);
79
- const size_t s2 = nb2 / ggml_element_size(dst);
80
- const size_t s3 = nb3 / ggml_element_size(dst);
81
-
82
- const size_t s10 = nb10 / ggml_element_size(src1);
83
- const size_t s11 = nb11 / ggml_element_size(src1);
84
- const size_t s12 = nb12 / ggml_element_size(src1);
85
- //const size_t s13 = nb13 / ggml_element_size(src1);
86
-
87
- GGML_ASSERT(ne00 % 2 == 0);
88
-
89
- k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
90
- src0_dd, src1_dd, dst_dd,
91
- ne00, /*ne01, ne02, ne03,*/
92
- /*ne10, ne11,*/ ne12, /*ne13,*/
93
- /* s0,*/ s1, s2, s3,
94
- /* nb00,*/ nb01, nb02, nb03,
95
- s10, s11, s12/*, s13*/);
96
-
97
- GGML_UNUSED(dst);
98
- }
99
-
100
- template<typename src0_t>
101
- static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
102
- const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
103
-
104
- GGML_TENSOR_BINARY_OP_LOCALS
105
-
106
- const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
107
- const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
108
- const dim3 block_nums(block_num_x, ne10, ne11*ne12);
109
-
110
- // strides in elements
111
- //const size_t s0 = nb0 / ggml_element_size(dst);
112
- const size_t s1 = nb1 / ggml_element_size(dst);
113
- const size_t s2 = nb2 / ggml_element_size(dst);
114
- const size_t s3 = nb3 / ggml_element_size(dst);
115
-
116
- const size_t s10 = nb10 / ggml_element_size(src1);
117
- const size_t s11 = nb11 / ggml_element_size(src1);
118
- const size_t s12 = nb12 / ggml_element_size(src1);
119
- //const size_t s13 = nb13 / ggml_element_size(src1);
120
-
121
- k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
122
- src0_dd, src1_dd, dst_dd,
123
- ne00, /*ne01, ne02, ne03,*/
124
- /*ne10, ne11,*/ ne12, /*ne13,*/
125
- /* s0,*/ s1, s2, s3,
126
- /* nb00,*/ nb01, nb02, nb03,
127
- s10, s11, s12/*, s13*/);
128
-
129
- GGML_UNUSED(dst);
130
- }
131
-
132
- void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
133
- const ggml_tensor * src0 = dst->src[0];
134
- const ggml_tensor * src1 = dst->src[1];
135
- const float * src0_d = (const float *)src0->data;
136
- const float * src1_d = (const float *)src1->data;
137
- float * dst_d = (float *)dst->data;
138
- cudaStream_t stream = ctx.stream();
139
-
140
-
141
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
142
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
143
-
144
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
145
- GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
146
- GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
147
-
148
- const int32_t * src1_i32 = (const int32_t *) src1_d;
149
-
150
- switch (src0->type) {
151
- case GGML_TYPE_F16:
152
- get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
153
- break;
154
- case GGML_TYPE_F32:
155
- get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
156
- break;
157
- case GGML_TYPE_Q4_0:
158
- get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
159
- break;
160
- case GGML_TYPE_Q4_1:
161
- get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
162
- break;
163
- case GGML_TYPE_Q5_0:
164
- get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
165
- break;
166
- case GGML_TYPE_Q5_1:
167
- get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
168
- break;
169
- case GGML_TYPE_Q8_0:
170
- get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
171
- break;
172
- default:
173
- // TODO: k-quants
174
- fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
175
- GGML_ASSERT(false);
176
- break;
177
- }
178
- }
@@ -1,104 +0,0 @@
1
- #include "im2col.cuh"
2
-
3
// im2col: expand convolution input patches into columns (1D or 2D).
// One thread writes one (kx, ky, ix) tap of one output row.
// Expected launch geometry (see im2col_cuda):
//   grid.x = ceil(pelements / blockDim.x), grid.y = OH, grid.z = batch * IC
template <typename T>
static __global__ void im2col_kernel(
        const float * x, T * dst, int64_t batch_offset,
        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
        int s0, int s1, int p0, int p1, int d0, int d1) {
    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
    // guard the grid tail: pelements rarely divides evenly into blocks
    if (i >= pelements) {
        return;
    }

    // decompose the flat index i into (kx, ky, ix);
    // for a 1D kernel (KH == 1) the KW factor collapses out of ksize
    const int64_t ksize = OW * (KH > 1 ? KW : 1);
    const int64_t kx = i / ksize;
    const int64_t kd = kx * ksize;
    const int64_t ky = (i - kd) / OW;
    const int64_t ix = i % OW;

    // remaining coordinates come straight from the launch geometry
    const int64_t oh = blockIdx.y;        // output row
    const int64_t batch = blockIdx.z / IC; // batch index
    const int64_t ic = blockIdx.z % IC;    // input channel

    // input position for this tap, applying stride, dilation and padding
    const int64_t iiw = ix * s0 + kx * d0 - p0;
    const int64_t iih = oh * s1 + ky * d1 - p1;

    // destination layout: [batch, OH, OW, CHW] with CHW = IC*KH*KW
    const int64_t offset_dst =
        ((batch * OH + oh) * OW + ix) * CHW +
        (ic * (KW * KH) + ky * KW + kx);

    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
        // taps that fall in the padding region read as zero
        dst[offset_dst] = 0.0f;
    } else {
        // offset_delta / batch_offset are element (not byte) strides into x
        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
        dst[offset_dst] = x[offset_src + iih * IW + iiw];
    }
}
37
-
38
// Launch im2col_kernel over the whole batch.
// Grid: x covers the OW*KW*KH taps of one output row, y = OH output rows,
// z = batch * IC (batch, channel) pairs.
template <typename T>
static void im2col_cuda(const float * x, T* dst,
    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
    int64_t batch, int64_t batch_offset, int64_t offset_delta,
    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
    // compute in 64 bits: OW*KW*KH is an int64_t product and can overflow a
    // 32-bit int for large tensors (the kernel takes pelements as int64_t)
    const int64_t parallel_elements = OW * KW * KH;
    const int64_t num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
    dim3 block_nums(num_blocks, OH, batch * IC);
    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
48
-
49
// typed convenience wrapper: im2col with a half-precision destination
static void im2col_cuda_f16(const float * x, half * dst,
    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
    int64_t batch, int64_t batch_offset, int64_t offset_delta,
    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
    // template argument deduced from dst
    im2col_cuda(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
}
56
-
57
// typed convenience wrapper: im2col with a single-precision destination
static void im2col_cuda_f32(const float * x, float * dst,
    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
    int64_t batch, int64_t batch_offset, int64_t offset_delta,
    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
    // template argument deduced from dst
    im2col_cuda(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
}
64
-
65
// CUDA implementation of GGML_OP_IM2COL.
// src0 = convolution kernel (F16; only its geometry is read here),
// src1 = input data (F32); dst may be F16 or F32.
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    const float * src1_d = (const float *)src1->data;
    float       * dst_d  = (float *)dst->data;
    cudaStream_t  stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);

    // op_params layout: [s0, s1, p0, p1, d0, d1, is_2D]
    const int32_t * params = (const int32_t *) dst->op_params;
    const int32_t s0 = params[0]; // stride   x
    const int32_t s1 = params[1]; // stride   y
    const int32_t p0 = params[2]; // padding  x
    const int32_t p1 = params[3]; // padding  y
    const int32_t d0 = params[4]; // dilation x
    const int32_t d1 = params[5]; // dilation y

    const bool is_2D = params[6] == 1;

    // geometry; 1D convolutions collapse the height dimension to 1
    const int64_t IC = src1->ne[is_2D ? 2 : 1];
    const int64_t IH = is_2D ? src1->ne[1] : 1;
    const int64_t IW = src1->ne[0];

    const int64_t KH = is_2D ? src0->ne[1] : 1;
    const int64_t KW = src0->ne[0];

    const int64_t OH = is_2D ? dst->ne[2] : 1;
    const int64_t OW = dst->ne[1];

    // nb[] holds byte strides; src1 is F32, so convert to element strides
    const size_t  delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float);
    const size_t  batch_offset = src1->nb[3] / sizeof(float);
    const int64_t batch        = src1->ne[3];

    if (dst->type == GGML_TYPE_F16) {
        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
    } else {
        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
    }
}