llama_cpp 0.15.4 → 0.16.0

Files changed (147)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +15 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +13 -1
  7. data/vendor/tmp/llama.cpp/Makefile +62 -35
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +8 -6
  131. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -6
  132. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  133. data/vendor/tmp/llama.cpp/ggml-metal.m +34 -24
  134. data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
  135. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +2 -2
  136. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +7 -67
  137. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  138. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +456 -329
  139. data/vendor/tmp/llama.cpp/ggml.c +178 -330
  140. data/vendor/tmp/llama.cpp/ggml.h +9 -28
  141. data/vendor/tmp/llama.cpp/llama.cpp +242 -426
  142. data/vendor/tmp/llama.cpp/llama.h +17 -43
  143. metadata +121 -6
  144. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  145. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  146. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  147. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
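
The bulk of this release is a sync of the vendored llama.cpp sources: the monolithic CUDA backend is split into per-op files under ggml-cuda/ (plus per-head-size flash-attention template instances), while the MPI and OpenCL backends are removed. On the Ruby side the bindings (llama_cpp.cpp), RBS signatures, and version files change as well, so check data/CHANGELOG.md before upgrading. Below is a minimal sketch of picking up the new release, assuming an ordinary Bundler setup; the Gemfile line and command are standard usage and are not part of this diff.

    # Gemfile (hypothetical project setup, not part of this diff)
    gem 'llama_cpp', '~> 0.16.0'

    # then reinstall so the native extension rebuilds against the new vendored llama.cpp:
    #   $ bundle update llama_cpp
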
@@ -0,0 +1,345 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-tile-f16.cuh"
+#include "fattn-tile-f32.cuh"
+#include "fattn-vec-f16.cuh"
+#include "fattn-vec-f32.cuh"
+#include "fattn-wmma-f16.cuh"
+#include "fattn.cuh"
+
+#include <cstdint>
+
+static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    const int32_t precision = KQV->op_params[2];
+
+    if (precision != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                case 256:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                // case 256:
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                //     break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+        constexpr int cols_per_block = 8;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 16;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 80:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 112:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    switch (Q->ne[0]) {
+        case 64:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+            break;
+        case 80:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+            break;
+        case 96:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+            break;
+        case 112:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+            break;
+        case 128:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+            break;
+        case 256:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
+        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
+        return;                                                             \
+    }                                                                       \
+
+static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[0];
+    ggml_tensor * K = dst->src[1];
+    ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#else
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+    on_no_fattn_vec_case(Q->ne[0]);
+}
+
+#define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
+        ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
+        return;                                                             \
+    }                                                                       \
+
+static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[0];
+    ggml_tensor * K = dst->src[1];
+    ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#else
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+    on_no_fattn_vec_case(Q->ne[0]);
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    ggml_cuda_set_device(ctx.device);
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int32_t precision = KQV->op_params[2];
+
+    // On AMD the tile kernels perform poorly, use the vec kernel instead:
+    if (cc >= CC_OFFSET_AMD) {
+        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        }
+        return;
+    }
+
+    if (!fast_fp16_available(cc)) {
+        if (Q->ne[1] <= 8) {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+        }
+        return;
+    }
+
+    if (!fp16_mma_available(cc)) {
+        if (Q->ne[1] <= 8) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+        if (precision == GGML_PREC_DEFAULT) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            return;
+        } else if (Q->ne[0] <= 128) {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            return;
+        }
+    }
+
+    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+}
@@ -0,0 +1,178 @@
+#include "getrows.cuh"
+#include "dequantize.cuh"
+
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void k_get_rows(
+        const void * src0, const int32_t * src1, dst_t * dst,
+        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+        size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(src0_row, ib, iqs, v);
+
+    dst_row[iybs + iqs + 0] = v.x;
+    dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+        const src0_t * src0, const int32_t * src1, dst_t * dst,
+        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+        size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+    dst_row[i00] = src0_row[i00];
+}
+
+template<int qk, int qr, dequantize_kernel_t dq>
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+        src0_dd, src1_dd, dst_dd,
+        ne00, /*ne01, ne02, ne03,*/
+        /*ne10, ne11,*/ ne12, /*ne13,*/
+        /* s0,*/ s1, s2, s3,
+        /* nb00,*/ nb01, nb02, nb03,
+        s10, s11, s12/*, s13*/);
+
+    GGML_UNUSED(dst);
+}
+
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+        src0_dd, src1_dd, dst_dd,
+        ne00, /*ne01, ne02, ne03,*/
+        /*ne10, ne11,*/ ne12, /*ne13,*/
+        /* s0,*/ s1, s2, s3,
+        /* nb00,*/ nb01, nb02, nb03,
+        s10, s11, s12/*, s13*/);
+
+    GGML_UNUSED(dst);
+}
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+    const int32_t * src1_i32 = (const int32_t *) src1_d;
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_F32:
+            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_0:
+            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        default:
+            // TODO: k-quants
+            fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            GGML_ASSERT(false);
+            break;
+    }
+}
@@ -0,0 +1,104 @@
+#include "im2col.cuh"
+
+template <typename T>
+static __global__ void im2col_kernel(
+        const float * x, T * dst, int64_t batch_offset,
+        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i >= pelements) {
+        return;
+    }
+
+    const int64_t ksize = OW * (KH > 1 ? KW : 1);
+    const int64_t kx = i / ksize;
+    const int64_t kd = kx * ksize;
+    const int64_t ky = (i - kd) / OW;
+    const int64_t ix = i % OW;
+
+    const int64_t oh = blockIdx.y;
+    const int64_t batch = blockIdx.z / IC;
+    const int64_t ic = blockIdx.z % IC;
+
+    const int64_t iiw = ix * s0 + kx * d0 - p0;
+    const int64_t iih = oh * s1 + ky * d1 - p1;
+
+    const int64_t offset_dst =
+        ((batch * OH + oh) * OW + ix) * CHW +
+        (ic * (KW * KH) + ky * KW + kx);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+template <typename T>
+static void im2col_cuda(const float * x, T * dst,
+        int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+        int64_t batch, int64_t batch_offset, int64_t offset_delta,
+        int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+    const int parallel_elements = OW * KW * KH;
+    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    dim3 block_nums(num_blocks, OH, batch * IC);
+    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
+static void im2col_cuda_f16(const float * x, half * dst,
+        int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+        int64_t batch, int64_t batch_offset, int64_t offset_delta,
+        int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+
+    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+}
+
+static void im2col_cuda_f32(const float * x, float * dst,
+        int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+        int64_t batch, int64_t batch_offset, int64_t offset_delta,
+        int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+
+    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+}
+
+void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW = src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW = src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW = dst->ne[1];
+
+    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch = src1->ne[3];
+    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+
+    if (dst->type == GGML_TYPE_F16) {
+        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+    } else {
+        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+    }
+}