llama_cpp 0.16.2 → 0.17.0

Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -12
  4. data/ext/llama_cpp/extconf.rb +2 -43
  5. data/ext/llama_cpp/llama_cpp.cpp +8 -0
  6. data/lib/llama_cpp/version.rb +3 -3
  7. data/sig/llama_cpp.rbs +3 -0
  8. metadata +2 -171
  9. data/vendor/include/.gitkeep +0 -0
  10. data/vendor/lib/.gitkeep +0 -0
  11. data/vendor/tmp/llama.cpp/LICENSE +0 -21
  12. data/vendor/tmp/llama.cpp/Makefile +0 -1124
  13. data/vendor/tmp/llama.cpp/ggml-alloc.c +0 -1041
  14. data/vendor/tmp/llama.cpp/ggml-alloc.h +0 -76
  15. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +0 -153
  16. data/vendor/tmp/llama.cpp/ggml-backend.c +0 -2225
  17. data/vendor/tmp/llama.cpp/ggml-backend.h +0 -236
  18. data/vendor/tmp/llama.cpp/ggml-blas.cpp +0 -363
  19. data/vendor/tmp/llama.cpp/ggml-blas.h +0 -23
  20. data/vendor/tmp/llama.cpp/ggml-common.h +0 -1805
  21. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +0 -47
  22. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +0 -34
  23. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +0 -104
  24. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +0 -280
  25. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +0 -34
  26. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +0 -196
  27. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +0 -686
  28. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +0 -490
  29. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +0 -40
  30. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +0 -674
  31. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +0 -319
  32. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +0 -312
  33. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +0 -345
  34. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +0 -178
  35. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +0 -104
  36. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +0 -88
  37. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +0 -419
  38. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +0 -221
  39. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +0 -49
  40. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +0 -94
  41. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +0 -112
  42. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +0 -271
  43. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +0 -31
  44. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +0 -206
  45. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +0 -40
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  127. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  128. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  129. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  130. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  131. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  132. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +0 -10
  133. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +0 -9
  134. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +0 -10
  135. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +0 -10
  136. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +0 -8
  137. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +0 -5
  138. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +0 -5
  139. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +0 -5
  140. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +0 -5
  141. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +0 -5
  142. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +0 -5
  143. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +0 -5
  144. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +0 -5
  145. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +0 -5
  146. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +0 -5
  147. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +0 -47
  148. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +0 -314
  149. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +0 -51
  150. data/vendor/tmp/llama.cpp/ggml-cuda.cu +0 -3069
  151. data/vendor/tmp/llama.cpp/ggml-cuda.h +0 -44
  152. data/vendor/tmp/llama.cpp/ggml-impl.h +0 -651
  153. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -2038
  154. data/vendor/tmp/llama.cpp/ggml-kompute.h +0 -46
  155. data/vendor/tmp/llama.cpp/ggml-metal.h +0 -66
  156. data/vendor/tmp/llama.cpp/ggml-metal.m +0 -3273
  157. data/vendor/tmp/llama.cpp/ggml-metal.metal +0 -6540
  158. data/vendor/tmp/llama.cpp/ggml-quants.c +0 -14994
  159. data/vendor/tmp/llama.cpp/ggml-quants.h +0 -133
  160. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +0 -1178
  161. data/vendor/tmp/llama.cpp/ggml-rpc.h +0 -24
  162. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +0 -6351
  163. data/vendor/tmp/llama.cpp/ggml-sycl.h +0 -40
  164. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +0 -144508
  165. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +0 -7183
  166. data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -29
  167. data/vendor/tmp/llama.cpp/ggml.c +0 -22506
  168. data/vendor/tmp/llama.cpp/ggml.h +0 -2458
  169. data/vendor/tmp/llama.cpp/llama.cpp +0 -18985
  170. data/vendor/tmp/llama.cpp/llama.h +0 -1147
  171. data/vendor/tmp/llama.cpp/scripts/get-flags.mk +0 -38
  172. data/vendor/tmp/llama.cpp/sgemm.cpp +0 -1032
  173. data/vendor/tmp/llama.cpp/sgemm.h +0 -14
  174. data/vendor/tmp/llama.cpp/unicode-data.cpp +0 -7033
  175. data/vendor/tmp/llama.cpp/unicode-data.h +0 -20
  176. data/vendor/tmp/llama.cpp/unicode.cpp +0 -810
  177. data/vendor/tmp/llama.cpp/unicode.h +0 -63
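
The listing above shows that 0.17.0 stops bundling the llama.cpp sources under data/vendor/tmp/ (Makefile, ggml*, the CUDA/Metal/Vulkan/SYCL/Kompute backends, unicode tables, and so on) and trims extconf.rb accordingly. The hunk below is the removal of data/vendor/tmp/llama.cpp/ggml-kompute.cpp, the Kompute/Vulkan backend. Its device discovery reduces to: enumerate the Vulkan physical devices, reject any that lack Vulkan 1.2 or the 8/16-bit storage and fp16/int8 shader features the SPIR-V kernels need, and prefer discrete GPUs (the full code additionally filters by device-local heap size and subgroup size). For orientation only, here is a condensed sketch of that selection logic written against vulkan.hpp directly; the helper name and structure are illustrative, not part of the gem or of llama.cpp:

#include <vulkan/vulkan.hpp>
#include <algorithm>
#include <vector>

// Condensed from the removed ggml-kompute.cpp: keep only devices that expose
// the features the backend's SPIR-V kernels rely on, then put discrete GPUs first.
static std::vector<vk::PhysicalDevice> usable_vulkan_devices(vk::Instance instance) {
    std::vector<vk::PhysicalDevice> usable;
    for (const auto & dev : instance.enumeratePhysicalDevices()) {
        const auto props = dev.getProperties();
        if (VK_VERSION_MAJOR(props.apiVersion) < 1 || VK_VERSION_MINOR(props.apiVersion) < 2) {
            continue; // the backend requires Vulkan 1.2
        }
        vk::PhysicalDeviceVulkan12Features f12;
        vk::PhysicalDeviceVulkan11Features f11; f11.pNext = &f12;
        vk::PhysicalDeviceFeatures2        f2;  f2.pNext  = &f11;
        dev.getFeatures2(&f2);
        if (!f2.features.shaderInt16 ||
            !f11.storageBuffer16BitAccess || !f11.uniformAndStorageBuffer16BitAccess ||
            !f12.storageBuffer8BitAccess  || !f12.uniformAndStorageBuffer8BitAccess  ||
            !f12.shaderFloat16 || !f12.shaderInt8) {
            continue; // missing 8/16-bit storage access or fp16/int8 shader support
        }
        usable.push_back(dev);
    }
    // Prefer discrete GPUs, as the removed backend's stable_sort does.
    std::stable_sort(usable.begin(), usable.end(),
        [](const vk::PhysicalDevice & a, const vk::PhysicalDevice & b) {
            return a.getProperties().deviceType == vk::PhysicalDeviceType::eDiscreteGpu &&
                   b.getProperties().deviceType != vk::PhysicalDeviceType::eDiscreteGpu;
        });
    return usable;
}

In the removed file this logic lives in ggml_vk_available_devices_internal, which additionally records heap size, subgroup size, vendor name, buffer alignment, and maximum allocation size for each surviving device before sorting.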
@@ -1,2038 +0,0 @@
1
- #include "ggml.h"
2
- #include "ggml-backend.h"
3
- #include "ggml-backend-impl.h"
4
- #include "ggml-kompute.h"
5
-
6
- // These are generated at build time by cmake custom command
7
- #include "shaderop_scale.h"
8
- #include "shaderop_scale_8.h"
9
- #include "shaderop_add.h"
10
- #include "shaderop_addrow.h"
11
- #include "shaderop_mul.h"
12
- #include "shaderop_silu.h"
13
- #include "shaderop_relu.h"
14
- #include "shaderop_gelu.h"
15
- #include "shaderop_softmax.h"
16
- #include "shaderop_norm.h"
17
- #include "shaderop_rmsnorm.h"
18
- #include "shaderop_diagmask.h"
19
- #include "shaderop_mul_mat_f16.h"
20
- #include "shaderop_mul_mat_q8_0.h"
21
- #include "shaderop_mul_mat_q4_0.h"
22
- #include "shaderop_mul_mat_q4_1.h"
23
- #include "shaderop_mul_mat_q6_k.h"
24
- #include "shaderop_mul_mat_mat_f32.h"
25
- #include "shaderop_getrows_f32.h"
26
- #include "shaderop_getrows_f16.h"
27
- #include "shaderop_getrows_q4_0.h"
28
- #include "shaderop_getrows_q4_1.h"
29
- #include "shaderop_getrows_q6_k.h"
30
- #include "shaderop_rope_f16.h"
31
- #include "shaderop_rope_f32.h"
32
- #include "shaderop_cpy_f16_f16.h"
33
- #include "shaderop_cpy_f16_f32.h"
34
- #include "shaderop_cpy_f32_f16.h"
35
- #include "shaderop_cpy_f32_f32.h"
36
-
37
- #include <algorithm>
38
- #include <array>
39
- #include <cassert>
40
- #include <cstdint>
41
- #include <cstdio>
42
- #include <cstring>
43
- #include <iostream>
44
- #include <memory>
45
- #include <stdexcept>
46
- #include <string>
47
- #include <unordered_map>
48
- #include <utility>
49
- #include <vector>
50
-
51
- #include <kompute/Kompute.hpp>
52
- #include <vulkan/vulkan.hpp>
53
-
54
- #ifdef __linux__
55
- #include <cstdlib> // for setenv
56
- #endif
57
-
58
- #define QK4_0 32
59
- #define QR4_0 2
60
- #define QK4_1 32
61
- #define QK_NL 16
62
-
63
- typedef ggml_fp16_t half;
64
-
65
- static std::string ggml_kompute_format_name(int device) {
66
- return "Kompute" + std::to_string(device);
67
- }
68
-
69
- struct ggml_kompute_context {
70
- int device;
71
- std::string name;
72
- std::shared_ptr<vk::DescriptorPool> pool;
73
-
74
- ggml_kompute_context(int device)
75
- : device(device), name(ggml_kompute_format_name(device)) {}
76
- };
77
-
78
- // FIXME: It would be good to consolidate the kompute manager and the kompute context into one object
79
- // and consolidate the init functions and simplify object lifetime management. As it currently stands,
80
- // we *have* to have the kompute manager no matter what for device discovery, but the kompute context
81
- // is only created when a device is set and vulkan is explicitly turned on.
82
- static ggml_kompute_context *s_kompute_context = nullptr;
83
-
84
- class kompute_manager {
85
- kp::Manager *s_mgr = nullptr;
86
-
87
- public:
88
- kp::Manager *operator()() {
89
- if (s_mgr && !s_mgr->hasInstance()) {
90
- destroy();
91
- }
92
- if (!s_mgr) {
93
- s_mgr = new kp::Manager;
94
- }
95
- return s_mgr;
96
- }
97
-
98
- void destroy() {
99
- delete s_mgr;
100
- s_mgr = nullptr;
101
- }
102
- };
103
-
104
- static kompute_manager komputeManager;
105
-
106
- struct ggml_vk_memory {
107
- void *data = nullptr;
108
- size_t size = 0;
109
- vk::DeviceMemory *primaryMemory = nullptr;
110
- vk::Buffer *primaryBuffer = nullptr;
111
- vk::DeviceMemory *stagingMemory = nullptr;
112
- vk::Buffer *stagingBuffer = nullptr;
113
- };
114
-
115
- #ifdef __linux__
116
- __attribute__((constructor))
117
- static void enable_sam() {
118
- setenv("RADV_PERFTEST", "sam", false);
119
- }
120
- #endif
121
-
122
- static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) {
123
- vk::PhysicalDeviceFeatures availableFeatures;
124
- physical_device.getFeatures(&availableFeatures);
125
-
126
- if (!availableFeatures.shaderInt16)
127
- return false;
128
-
129
- vk::PhysicalDeviceVulkan11Features availableFeatures11;
130
- vk::PhysicalDeviceVulkan12Features availableFeatures12;
131
-
132
- availableFeatures11.pNext = &availableFeatures12;
133
- availableFeatures12.pNext = nullptr;
134
-
135
- vk::PhysicalDeviceFeatures2 features2;
136
- features2.pNext = &availableFeatures11;
137
-
138
- physical_device.getFeatures2(&features2);
139
-
140
- if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
141
- !availableFeatures11.storageBuffer16BitAccess) {
142
- return false;
143
- }
144
-
145
- if (!availableFeatures12.storageBuffer8BitAccess ||
146
- !availableFeatures12.uniformAndStorageBuffer8BitAccess ||
147
- !availableFeatures12.shaderFloat16 ||
148
- !availableFeatures12.shaderInt8) {
149
- return false;
150
- }
151
-
152
- return true;
153
- }
154
-
155
- static const char * ggml_vk_getVendorName(uint32_t vendorID) {
156
- switch (vendorID) {
157
- case 0x10DE:
158
- return "nvidia";
159
- case 0x1002:
160
- return "amd";
161
- case 0x8086:
162
- return "intel";
163
- default:
164
- return "unknown";
165
- }
166
- }
167
-
168
- static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t memoryRequired) {
169
- std::vector<ggml_vk_device> results;
170
- if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance())
171
- return results;
172
-
173
- std::vector<vk::PhysicalDevice> physical_devices;
174
- try {
175
- physical_devices = komputeManager()->listDevices();
176
- } catch (vk::SystemError & err) {
177
- std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n";
178
- return results;
179
- }
180
-
181
- uint32_t deviceCount = physical_devices.size();
182
- if (deviceCount == 0)
183
- return results;
184
-
185
- std::unordered_map<std::string, size_t> count_by_name;
186
-
187
- for (uint32_t i = 0; i < deviceCount; i++) {
188
- const auto & physical_device = physical_devices[i];
189
-
190
- VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
191
- VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties();
192
- const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion);
193
- const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion);
194
- if (major < 1 || minor < 2)
195
- continue;
196
-
197
- if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device))
198
- continue;
199
-
200
- size_t heapSize = 0;
201
- for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) {
202
- VkMemoryHeap heap = memoryProperties.memoryHeaps[j];
203
- if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) {
204
- heapSize = heap.size;
205
- break;
206
- }
207
- }
208
-
209
- if (heapSize < memoryRequired)
210
- continue;
211
-
212
- auto ext_props = physical_device.enumerateDeviceExtensionProperties();
213
- bool has_maintenance4 = false;
214
-
215
- // Check if maintenance4 is supported
216
- for (const auto & properties : ext_props) {
217
- if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
218
- has_maintenance4 = true;
219
- }
220
- }
221
-
222
- vk::PhysicalDeviceSubgroupProperties subgroup_props;
223
- vk::PhysicalDeviceProperties2 dev_props2;
224
- vk::PhysicalDeviceMaintenance3Properties dev_props3;
225
- vk::PhysicalDeviceMaintenance4Properties dev_props4;
226
- dev_props2.pNext = &dev_props3;
227
- dev_props3.pNext = &subgroup_props;
228
- if (has_maintenance4) {
229
- subgroup_props.pNext = &dev_props4;
230
- }
231
- physical_device.getProperties2(&dev_props2);
232
-
233
- if (subgroup_props.subgroupSize < 32)
234
- continue;
235
-
236
- ggml_vk_device d;
237
- d.index = i;
238
- d.type = dev_props.deviceType;
239
- d.heapSize = heapSize;
240
- d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID));
241
- d.subgroupSize = subgroup_props.subgroupSize;
242
- d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment;
243
-
244
- if (has_maintenance4) {
245
- d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize);
246
- } else {
247
- d.maxAlloc = dev_props3.maxMemoryAllocationSize;
248
- }
249
-
250
- std::string name(dev_props.deviceName);
251
- size_t n_idx = ++count_by_name[name];
252
- if (n_idx > 1) {
253
- name += " (" + std::to_string(n_idx) + ")";
254
- }
255
- d.name = strdup(name.c_str());
256
-
257
- results.push_back(d);
258
- }
259
-
260
- std::stable_sort(results.begin(), results.end(),
261
- [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool {
262
- if (lhs.type != rhs.type) {
263
- if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true;
264
- if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false;
265
-
266
- if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true;
267
- if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false;
268
- }
269
- return lhs.heapSize < rhs.heapSize;
270
- }
271
- );
272
-
273
- return results;
274
- }
275
-
276
- // public API returns a C-style array
277
- ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count) {
278
- auto devices = ggml_vk_available_devices_internal(memoryRequired);
279
- *count = devices.size();
280
- if (devices.empty()) {
281
- return nullptr;
282
- }
283
-
284
- size_t nbytes = sizeof (ggml_vk_device) * (devices.size());
285
- auto * arr = static_cast<ggml_vk_device *>(malloc(nbytes));
286
- memcpy(arr, devices.data(), nbytes);
287
- return arr;
288
- }
289
-
290
- static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
291
- devices.erase(
292
- std::remove_if(devices.begin(), devices.end(),
293
- [&targetVendor](const ggml_vk_device& device) {
294
- return device.vendor != targetVendor;
295
- }),
296
- devices.end()
297
- );
298
- }
299
-
300
- static void ggml_vk_filterByName(std::vector<ggml_vk_device>& devices, const std::string& targetName) {
301
- devices.erase(
302
- std::remove_if(devices.begin(), devices.end(),
303
- [&targetName](const ggml_vk_device& device) {
304
- return device.name != targetName;
305
- }),
306
- devices.end()
307
- );
308
- }
309
-
310
- static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) {
311
- if (name.empty())
312
- return false;
313
-
314
- auto devices = ggml_vk_available_devices_internal(memoryRequired);
315
- if (name == "amd" || name == "nvidia" || name == "intel") {
316
- ggml_vk_filterByVendor(devices, name);
317
- } else if (name != "gpu") {
318
- ggml_vk_filterByName(devices, name);
319
- }
320
-
321
- if (devices.empty())
322
- return false;
323
-
324
- *device = devices.front();
325
- return true;
326
- }
327
-
328
- bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) {
329
- return ggml_vk_get_device(device, memoryRequired, std::string(name));
330
- }
331
-
332
- bool ggml_vk_has_vulkan() {
333
- return komputeManager()->hasVulkan();
334
- }
335
-
336
- bool ggml_vk_has_device() {
337
- return komputeManager()->hasDevice();
338
- }
339
-
340
- ggml_vk_device ggml_vk_current_device() {
341
- if (!komputeManager()->hasDevice())
342
- return ggml_vk_device();
343
-
344
- auto devices = ggml_vk_available_devices_internal(0);
345
- ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
346
- GGML_ASSERT(!devices.empty());
347
- return devices.front();
348
- }
349
-
350
- static
351
- void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t size) {
352
- std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
353
- vk::DescriptorPoolSize(
354
- vk::DescriptorType::eStorageBuffer,
355
- 3 * size // Descriptor count is number of possible tensors to pass into an algorithm
356
- )
357
- };
358
-
359
- vk::DescriptorPoolCreateInfo descriptorPoolInfo(
360
- vk::DescriptorPoolCreateFlags(),
361
- size, // Max sets
362
- static_cast<uint32_t>(descriptorPoolSizes.size()),
363
- descriptorPoolSizes.data());
364
-
365
- ctx->pool = std::make_shared<vk::DescriptorPool>();
366
- vk::Result r = komputeManager()->device()->createDescriptorPool(
367
- &descriptorPoolInfo, nullptr, ctx->pool.get());
368
- if (r != vk::Result::eSuccess)
369
- std::cerr << "Error allocating descriptor pool" << vk::to_string(r);
370
- }
371
-
372
- static
373
- void ggml_vk_free_descriptor_pool(struct ggml_kompute_context * ctx) {
374
- if (ctx->pool) {
375
- komputeManager()->device()->destroy(
376
- *ctx->pool,
377
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
378
- ctx->pool = nullptr;
379
- }
380
- }
381
-
382
- static
383
- vk::Buffer *ggml_vk_allocate_buffer(size_t size) {
384
- vk::BufferCreateInfo bufferCreateInfo;
385
- bufferCreateInfo.size = size;
386
- bufferCreateInfo.usage = vk::BufferUsageFlagBits::eStorageBuffer |
387
- vk::BufferUsageFlagBits::eTransferSrc |
388
- vk::BufferUsageFlagBits::eTransferDst;
389
- bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;
390
-
391
- vk::Buffer *vkBuffer = new vk::Buffer;
392
- vk::Result r = komputeManager()->device()->createBuffer(&bufferCreateInfo, nullptr, vkBuffer);
393
- if (r != vk::Result::eSuccess)
394
- std::cerr << "Error allocating buffer " << vk::to_string(r) << std::endl;
395
- return vkBuffer;
396
- }
397
-
398
- static
399
- vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, vk::MemoryRequirements requirements, bool *isHostVisible) {
400
-
401
- uint32_t memoryTypeIndex = -1;
402
- bool memoryTypeIndexFound = false;
403
- vk::PhysicalDeviceMemoryProperties memoryProperties = komputeManager()->physicalDevice()->getMemoryProperties();
404
- for (uint32_t i = 0; i < memoryProperties.memoryTypeCount; i++) {
405
- const vk::MemoryType &memoryType = memoryProperties.memoryTypes[i];
406
- const vk::MemoryHeap &memoryHeap = memoryProperties.memoryHeaps[memoryType.heapIndex];
407
- if (memoryHeap.size < size) {
408
- continue;
409
- }
410
-
411
- if (requirements.memoryTypeBits & (1 << i)) {
412
- if (((memoryProperties.memoryTypes[i]).propertyFlags &
413
- flags) == flags) {
414
- memoryTypeIndex = i;
415
- memoryTypeIndexFound = true;
416
- if (isHostVisible && (memoryProperties.memoryTypes[i].propertyFlags & vk::MemoryPropertyFlagBits::eHostVisible)) {
417
- *isHostVisible = true;
418
- }
419
- break;
420
- }
421
- }
422
- }
423
- if (!memoryTypeIndexFound) {
424
- throw std::runtime_error(
425
- "Memory type index for buffer creation not found");
426
- }
427
-
428
- vk::MemoryAllocateInfo allocInfo;
429
- allocInfo.allocationSize = size;
430
- allocInfo.memoryTypeIndex = memoryTypeIndex;
431
- vk::DeviceMemory *vkDeviceMemory = new vk::DeviceMemory;
432
- vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory);
433
- if (r != vk::Result::eSuccess) {
434
- std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl;
435
- throw std::runtime_error("Error allocating vulkan memory.");
436
- }
437
- return vkDeviceMemory;
438
- }
439
-
440
- static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) {
441
- size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer);
442
-
443
- // If offset is already aligned, return it directly
444
- if (offset % minStorageBufferOffsetAlignment == 0) {
445
- return offset;
446
- }
447
-
448
- // Otherwise, return the largest multiple of minStorageBufferOffsetAlignment less than offset
449
- return (offset / minStorageBufferOffsetAlignment) * minStorageBufferOffsetAlignment;
450
- }
451
-
452
- static ggml_vk_memory ggml_vk_allocate(size_t size) {
453
- ggml_vk_memory memory;
454
- bool isHostVisible = false;
455
- {
456
- memory.primaryBuffer = ggml_vk_allocate_buffer(size);
457
- vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer);
458
- vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal;
459
- memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
460
- komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0);
461
- if (isHostVisible) {
462
- vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
463
- if (r != vk::Result::eSuccess)
464
- std::cerr << "Error mapping memory" << vk::to_string(r);
465
- }
466
- }
467
-
468
- if (!isHostVisible) {
469
- memory.stagingBuffer = ggml_vk_allocate_buffer(size);
470
- vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.stagingBuffer);
471
- vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eHostVisible |
472
- vk::MemoryPropertyFlagBits::eHostCoherent |
473
- vk::MemoryPropertyFlagBits::eHostCached;
474
- memory.stagingMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
475
- komputeManager()->device()->bindBufferMemory(*memory.stagingBuffer, *memory.stagingMemory, 0);
476
- vk::Result r = komputeManager()->device()->mapMemory(*memory.stagingMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
477
- if (r != vk::Result::eSuccess)
478
- std::cerr << "Error mapping memory" << vk::to_string(r);
479
- }
480
-
481
- memory.size = size;
482
- return memory;
483
- }
484
-
485
- static void ggml_vk_free_memory(ggml_vk_memory &memory)
486
- {
487
- komputeManager()->device()->destroy(
488
- *memory.primaryBuffer,
489
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
490
- if (memory.stagingBuffer) {
491
- komputeManager()->device()->destroy(
492
- *memory.stagingBuffer,
493
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
494
- }
495
- komputeManager()->device()->freeMemory(
496
- *memory.primaryMemory,
497
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
498
- if (memory.stagingMemory) {
499
- komputeManager()->device()->freeMemory(
500
- *memory.stagingMemory,
501
- (vk::Optional<const vk::AllocationCallbacks>)nullptr);
502
- }
503
- }
504
-
505
- static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft);
506
-
507
- static
508
- ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & offset) {
509
- ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
510
-
511
- // compatibility with ggml-backend
512
- GGML_ASSERT(buffer && buffer->buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name);
513
-
514
- ggml_vk_memory * buf_ctx = static_cast<ggml_vk_memory *>(buffer->context);
515
-
516
- const intptr_t ioffs = intptr_t(t->data) - intptr_t(buf_ctx->data);
517
-
518
- GGML_ASSERT(ioffs >= 0 && ioffs + int64_t(ggml_nbytes(t)) <= int64_t(buffer->size));
519
-
520
- offset = uint64_t(ioffs);
521
- return buf_ctx;
522
- }
523
-
524
- static
525
- const std::shared_ptr<kp::Tensor> ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) {
526
- uint64_t originalOffset = 0;
527
- auto * res = ggml_vk_find_tensor(t, originalOffset);
528
- if (!res) {
529
- static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
530
- return nullTensor;
531
- }
532
-
533
- // Create a tensor whose memory will be composed of our buffers at the correct offset
534
- const size_t nelements = ggml_nelements(t);
535
- size_t nbytes = ggml_nbytes(t);
536
-
537
- size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset);
538
- if (alignedOffset) {
539
- *alignedOffset = originalOffset - vulkanOffset;
540
- nbytes += *alignedOffset;
541
- }
542
-
543
- return komputeManager()->tensor(
544
- t->data,
545
- nelements,
546
- nbytes, kp::Tensor::TensorDataTypes::eFloat,
547
- res->primaryMemory, res->primaryBuffer,
548
- res->stagingMemory, res->stagingBuffer,
549
- vulkanOffset);
550
- }
551
-
552
- static std::vector<uint32_t> getSpirvShader(const unsigned char* rawData, size_t size) {
553
- if (size % sizeof(uint32_t) != 0) {
554
- throw std::runtime_error("Invalid size: must be divisible by sizeof(uint32_t)");
555
- }
556
-
557
- const uint32_t* data_ptr = reinterpret_cast<const uint32_t*>(rawData);
558
- size_t count = size / sizeof(uint32_t);
559
- return std::vector<uint32_t>(data_ptr, data_ptr + count);
560
- }
561
-
562
- inline static
563
- uint32_t safe_divide(uint32_t a, uint32_t b) {
564
- if (b <= 1) {
565
- return a;
566
- }
567
- if ((a % b) != 0) {
568
- fprintf(stderr, "((%u %% %u) == %u) != 0\n", a, b, a % b);
569
- GGML_ASSERT(!"safe_divide result would've had remainder");
570
- }
571
- return a / b;
572
- }
573
-
574
- static void ggml_vk_add(
575
- kp::Sequence& seq,
576
- const std::shared_ptr<kp::Tensor>& inA,
577
- const std::shared_ptr<kp::Tensor>& inB,
578
- const std::shared_ptr<kp::Tensor>& out,
579
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
580
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
581
- int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
582
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
583
- int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
584
- int32_t ne0,
585
- int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3
586
- ) {
587
- const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv,
588
- kp::shader_data::op_add_comp_spv_len);
589
-
590
- struct PushConstants {
591
- uint32_t inAOff, inBOff, outOff;
592
- int32_t ne00;
593
- int32_t nb00, nb01, nb02, nb03;
594
- int32_t ne10, ne11, ne12, ne13;
595
- int32_t nb10, nb11, nb12, nb13;
596
- int32_t ne0;
597
- int32_t nb0, nb1, nb2, nb3;
598
- } const pushConsts {
599
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
600
- ne00,
601
- nb00, nb01, nb02, nb03,
602
- ne10, ne11, ne12, ne13,
603
- nb10, nb11, nb12, nb13,
604
- ne0,
605
- nb0, nb1, nb2, nb3
606
- };
607
-
608
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
609
- if (!komputeManager()->hasAlgorithm(__func__)) {
610
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
611
- } else {
612
- s_algo = komputeManager()->getAlgorithm(__func__);
613
- s_algo->setTensors({inA, inB, out});
614
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
615
- s_algo->setPushConstants<PushConstants>({pushConsts});
616
- s_algo->updateDescriptors(s_kompute_context->pool.get());
617
- }
618
- seq.record<kp::OpAlgoDispatch>(s_algo);
619
- }
620
-
621
- static void ggml_vk_addrow(kp::Sequence& seq,
622
- const std::shared_ptr<kp::Tensor>& inA,
623
- const std::shared_ptr<kp::Tensor>& inB,
624
- const std::shared_ptr<kp::Tensor>& out,
625
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
626
- uint32_t size, uint32_t row = 0) {
627
-
628
- const static auto spirv = getSpirvShader(kp::shader_data::op_addrow_comp_spv,
629
- kp::shader_data::op_addrow_comp_spv_len);
630
-
631
- struct PushConstants {
632
- uint32_t inAOff, inBOff, outOff;
633
- uint32_t row;
634
- } const pushConsts {
635
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
636
- row
637
- };
638
-
639
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
640
- if (!komputeManager()->hasAlgorithm(__func__))
641
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
642
- else {
643
- s_algo = komputeManager()->getAlgorithm(__func__);
644
- s_algo->setTensors({inA, inB, out});
645
- s_algo->setWorkgroup({size});
646
- s_algo->setPushConstants<PushConstants>({pushConsts});
647
- s_algo->updateDescriptors(s_kompute_context->pool.get());
648
- }
649
- seq.record<kp::OpAlgoDispatch>(s_algo);
650
- }
651
-
652
- static void ggml_vk_mul(
653
- kp::Sequence& seq,
654
- const std::shared_ptr<kp::Tensor>& inA,
655
- const std::shared_ptr<kp::Tensor>& inB,
656
- const std::shared_ptr<kp::Tensor>& out,
657
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
658
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
659
- int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
660
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
661
- int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
662
- int32_t ne0,
663
- int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3
664
- ) {
665
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv,
666
- kp::shader_data::op_mul_comp_spv_len);
667
-
668
- struct PushConstants {
669
- uint32_t inAOff, inBOff, outOff;
670
- int32_t ne00;
671
- int32_t nb00, nb01, nb02, nb03;
672
- int32_t ne10, ne11, ne12, ne13;
673
- int32_t nb10, nb11, nb12, nb13;
674
- int32_t ne0;
675
- int32_t nb0, nb1, nb2, nb3;
676
- } const pushConsts {
677
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
678
- ne00,
679
- nb00, nb01, nb02, nb03,
680
- ne10, ne11, ne12, ne13,
681
- nb10, nb11, nb12, nb13,
682
- ne0,
683
- nb0, nb1, nb2, nb3
684
- };
685
-
686
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
687
- if (!komputeManager()->hasAlgorithm(__func__)) {
688
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
689
- } else {
690
- s_algo = komputeManager()->getAlgorithm(__func__);
691
- s_algo->setTensors({inA, inB, out});
692
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
693
- s_algo->setPushConstants<PushConstants>({pushConsts});
694
- s_algo->updateDescriptors(s_kompute_context->pool.get());
695
- }
696
- seq.record<kp::OpAlgoDispatch>(s_algo);
697
- }
698
-
699
- static void ggml_vk_scale(kp::Sequence& seq,
700
- const std::shared_ptr<kp::Tensor>& in,
701
- const std::shared_ptr<kp::Tensor>& out,
702
- uint32_t inOff, uint32_t outOff,
703
- uint32_t size, float scale) {
704
- const static auto spirv_1 = getSpirvShader(
705
- kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
706
- );
707
- const static auto spirv_8 = getSpirvShader(
708
- kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
709
- );
710
-
711
- struct PushConstants {
712
- uint32_t inOff, outOff;
713
- float scale;
714
- } const pushConsts {
715
- safe_divide(inOff, 4), safe_divide(outOff, 4),
716
- scale
717
- };
718
-
719
- const auto * spirv = &spirv_1;
720
- std::string name(__func__);
721
- if (size % 8 == 0) {
722
- size /= 8;
723
- name += "_8";
724
- spirv = &spirv_8;
725
- }
726
-
727
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
728
- if (!komputeManager()->hasAlgorithm(name)) {
729
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
730
- } else {
731
- s_algo = komputeManager()->getAlgorithm(name);
732
- s_algo->setTensors({in, out});
733
- s_algo->setWorkgroup({size});
734
- s_algo->setPushConstants<PushConstants>({pushConsts});
735
- s_algo->updateDescriptors(s_kompute_context->pool.get());
736
- }
737
- seq.record<kp::OpAlgoDispatch>(s_algo);
738
- }
739
-
740
- static void ggml_vk_xxlu(
741
- const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
742
- const std::shared_ptr<kp::Tensor>& in,
743
- const std::shared_ptr<kp::Tensor>& out,
744
- uint32_t inOff, uint32_t outOff,
745
- uint32_t size
746
- ) {
747
- struct PushConstants {
748
- uint32_t inOff, outOff;
749
- } const pushConsts {
750
- safe_divide(inOff, 4), safe_divide(outOff, 4),
751
- };
752
-
753
- auto name = std::string(__func__) + "_" + suffix;
754
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
755
- if (!komputeManager()->hasAlgorithm(name)) {
756
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
757
- } else {
758
- s_algo = komputeManager()->getAlgorithm(name);
759
- s_algo->setTensors({in, out});
760
- s_algo->setWorkgroup({size});
761
- s_algo->setPushConstants<PushConstants>({pushConsts});
762
- s_algo->updateDescriptors(s_kompute_context->pool.get());
763
- }
764
- seq.record<kp::OpAlgoDispatch>(s_algo);
765
- }
766
-
767
- template <typename... Args>
768
- static void ggml_vk_silu(Args&&... args) {
769
- const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv,
770
- kp::shader_data::op_silu_comp_spv_len);
771
-
772
- ggml_vk_xxlu(spirv, "silu", std::forward<Args>(args)...);
773
- }
774
-
775
- template <typename... Args>
776
- static void ggml_vk_relu(Args&&... args) {
777
- const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv,
778
- kp::shader_data::op_relu_comp_spv_len);
779
-
780
- ggml_vk_xxlu(spirv, "relu", std::forward<Args>(args)...);
781
- }
782
-
783
- template <typename... Args>
784
- static void ggml_vk_gelu(Args&&... args) {
785
- const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv,
786
- kp::shader_data::op_gelu_comp_spv_len);
787
-
788
- ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
789
- }
790
-
791
- static void ggml_vk_soft_max(
792
- kp::Sequence& seq,
793
- const std::shared_ptr<kp::Tensor>& inA,
794
- const std::shared_ptr<kp::Tensor>& inB,
795
- const std::shared_ptr<kp::Tensor>& out,
796
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
797
- int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
798
- float scale
799
- ) {
800
- const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
801
- kp::shader_data::op_softmax_comp_spv_len);
802
-
803
- struct PushConstants {
804
- uint32_t inAOff, inBOff, outOff;
805
- int32_t ne00, ne01, ne02;
806
- float scale;
807
- int32_t mask;
808
- } pushConsts {
809
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
810
- ne00, ne01, ne02,
811
- scale,
812
- bool(inB)
813
- };
814
-
815
- auto & inB_ = inB ? inB : inA;
816
-
817
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
818
- if (!komputeManager()->hasAlgorithm(__func__)) {
819
- // FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device
820
- const uint32_t local_x = 32;
821
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
822
- } else {
823
- s_algo = komputeManager()->getAlgorithm(__func__);
824
- s_algo->setTensors({inA, inB_, out});
825
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
826
- s_algo->setPushConstants<PushConstants>({pushConsts});
827
- s_algo->updateDescriptors(s_kompute_context->pool.get());
828
- }
829
- seq.record<kp::OpAlgoDispatch>(s_algo);
830
- }
831
-
832
- static void ggml_vk_norm_(
833
- const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
834
- const std::shared_ptr<kp::Tensor>& in,
835
- const std::shared_ptr<kp::Tensor>& out,
836
- uint32_t inOff, uint32_t outOff,
837
- int32_t ne00, int32_t nb01,
838
- int32_t nrows, float epsilon
839
- ) {
840
- GGML_ASSERT(nb01%sizeof(float) == 0);
841
- GGML_ASSERT(ne00%sizeof(float) == 0);
842
-
843
- struct PushConstants {
844
- uint32_t inOff, outOff;
845
- uint32_t ne00, nb01;
846
- float eps;
847
- } pushConsts {
848
- safe_divide(inOff, 4), safe_divide(outOff, 4),
849
- (uint32_t)ne00, (uint32_t)nb01, epsilon
850
- };
851
-
852
- auto name = std::string(__func__) + "_" + suffix;
853
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
854
- if (!komputeManager()->hasAlgorithm(name)) {
855
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts});
856
- } else {
857
- s_algo = komputeManager()->getAlgorithm(name);
858
- s_algo->setTensors({in, out});
859
- s_algo->setWorkgroup({(uint32_t)nrows});
860
- s_algo->setPushConstants<PushConstants>({pushConsts});
861
- s_algo->updateDescriptors(s_kompute_context->pool.get());
862
- }
863
- seq.record<kp::OpAlgoDispatch>(s_algo);
864
- }
865
-
866
- template <typename... Args>
867
- static void ggml_vk_norm(Args&&... args) {
868
- const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv,
869
- kp::shader_data::op_norm_comp_spv_len);
870
-
871
- ggml_vk_norm_(spirv, "norm", std::forward<Args>(args)...);
872
- }
873
-
874
- template <typename... Args>
875
- static void ggml_vk_rms_norm(Args&&... args) {
876
- const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv,
877
- kp::shader_data::op_rmsnorm_comp_spv_len);
878
-
879
- ggml_vk_norm_(spirv, "rms", std::forward<Args>(args)...);
880
- }
881
-
882
- static void ggml_vk_diag_mask_inf(kp::Sequence& seq,
883
- const std::shared_ptr<kp::Tensor>& in,
884
- const std::shared_ptr<kp::Tensor>& out,
885
- uint32_t inOff, uint32_t outOff,
886
- uint32_t n_past,
887
- int32_t ne00, int32_t ne01, int32_t ne02) {
888
- const static auto spirv = getSpirvShader(kp::shader_data::op_diagmask_comp_spv,
889
- kp::shader_data::op_diagmask_comp_spv_len);
890
-
891
- struct PushConstants {
892
- uint32_t inOff, outOff;
893
- uint32_t n_past;
894
- int32_t ne00, ne01;
895
- } pushConsts {
896
- safe_divide(inOff, 4), safe_divide(outOff, 4),
897
- n_past,
898
- ne00, ne01
899
- };
900
-
901
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
902
- if (!komputeManager()->hasAlgorithm(__func__))
903
- s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts});
904
- else {
905
- s_algo = komputeManager()->getAlgorithm(__func__);
906
- s_algo->setTensors({in, out});
907
- s_algo->setWorkgroup({unsigned(ne00), unsigned(ne01), unsigned(ne02)});
908
- s_algo->setPushConstants<PushConstants>({pushConsts});
909
- s_algo->updateDescriptors(s_kompute_context->pool.get());
910
- }
911
- seq.record<kp::OpAlgoDispatch>(s_algo);
912
- }
913
-
914
- static void ggml_vk_mul_mat_f16(
915
- kp::Sequence& seq,
916
- const std::shared_ptr<kp::Tensor>& inA,
917
- const std::shared_ptr<kp::Tensor>& inB,
918
- const std::shared_ptr<kp::Tensor>& out,
919
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
920
- int32_t ne00, int32_t ne01, int32_t ne02,
921
- uint32_t nb00, uint32_t nb01, uint32_t nb02,
922
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
923
- uint32_t nb10, uint32_t nb11, uint32_t nb12,
924
- int32_t ne0, int32_t ne1,
925
- uint32_t r2, uint32_t r3
926
- ) {
927
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_f16_comp_spv,
928
- kp::shader_data::op_mul_mat_f16_comp_spv_len);
929
-
930
- struct PushConstants {
931
- uint32_t inAOff, inBOff, outOff;
932
- int32_t ne00, ne01, ne02;
933
- uint32_t nb00, nb01, nb02;
934
- int32_t ne10, ne11, ne12;
935
- uint32_t nb10, nb11, nb12;
936
- int32_t ne0, ne1;
937
- uint32_t r2, r3;
938
- } pushConsts {
939
- safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
940
- ne00, ne01, ne02,
941
- nb00, nb01, nb02,
942
- ne10, ne11, ne12,
943
- nb10, nb11, nb12,
944
- ne0, ne1,
945
- r2, r3
946
- };
947
-
948
- const unsigned ny = unsigned((ne11 + 4 - 1)/4);
949
-
950
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
951
- if (!komputeManager()->hasAlgorithm(__func__)) {
952
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
953
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), ny, unsigned(ne12*ne13)}, {local_x}, {pushConsts});
954
- } else {
955
- s_algo = komputeManager()->getAlgorithm(__func__);
956
- s_algo->setTensors({inA, inB, out});
957
- s_algo->setWorkgroup({unsigned(ne01), ny, unsigned(ne12*ne13)});
958
- s_algo->setPushConstants<PushConstants>({pushConsts});
959
- s_algo->updateDescriptors(s_kompute_context->pool.get());
960
- }
961
- seq.record<kp::OpAlgoDispatch>(s_algo);
962
- }
963
-
964
- static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
965
- const std::shared_ptr<kp::Tensor>& inA,
966
- const std::shared_ptr<kp::Tensor>& inB,
967
- const std::shared_ptr<kp::Tensor>& out,
968
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
969
- int32_t ne00, int32_t ne01, int32_t ne02,
970
- uint32_t nb01, uint32_t nb02,
971
- int32_t ne11, int32_t ne12,
972
- uint32_t nb11, uint32_t nb12,
973
- uint32_t nb1, uint32_t nb2) {
974
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_mat_f32_comp_spv,
975
- kp::shader_data::op_mul_mat_mat_f32_comp_spv_len);
976
-
977
- struct PushConstants {
978
- uint32_t inAOff, inBOff, outOff;
979
- int32_t ne00, ne01, ne02, ne11, ne12;
980
- uint32_t nb01, nb02;
981
- uint32_t nb11, nb12;
982
- uint32_t nb1, nb2;
983
- } pushConsts {
984
- safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
985
- ne00, ne01, ne02, ne11, ne12,
986
- nb01, nb02, nb11, nb12,
987
- nb1, nb2
988
- };
989
-
990
- const uint32_t local_x = ggml_vk_current_device().subgroupSize;
991
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
992
- if (!komputeManager()->hasAlgorithm(__func__)) {
993
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(),
994
- {inA, inB, out}, spirv,
995
- {unsigned(ne01),
996
- unsigned(ne11),
997
- unsigned(std::max(ne12, ne02))
998
- },
999
- {local_x},
1000
- {pushConsts});
1001
- } else {
1002
- s_algo = komputeManager()->getAlgorithm(__func__);
1003
- s_algo->setTensors({inA, inB, out});
1004
- s_algo->setWorkgroup({unsigned(ne01),
1005
- unsigned(ne11),
1006
- unsigned(std::max(ne12, ne02)),
1007
- });
1008
- s_algo->setPushConstants<PushConstants>({pushConsts});
1009
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1010
- }
1011
- seq.record<kp::OpAlgoDispatch>(s_algo);
1012
- }
1013
-
1014
- static void ggml_vk_mul_mat_impl(
1015
- const std::vector<uint32_t>& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq,
1016
- const std::shared_ptr<kp::Tensor>& inA,
1017
- const std::shared_ptr<kp::Tensor>& inB,
1018
- const std::shared_ptr<kp::Tensor>& out,
1019
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1020
- int32_t ne00, int32_t ne01, int32_t ne02,
1021
- int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
1022
- int32_t ne0, int32_t ne1,
1023
- uint32_t r2, uint32_t r3
1024
- ) {
1025
- struct PushConstants {
1026
- uint32_t inAOff, inBOff, outOff;
1027
- int32_t ne00, ne01, ne02;
1028
- int32_t ne10, ne12;
1029
- int32_t ne0, ne1;
1030
- uint32_t r2, r3;
1031
- } pushConsts {
1032
- safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
1033
- ne00, ne01, ne02,
1034
- ne10, ne12,
1035
- ne0, ne1,
1036
- r2, r3
1037
- };
1038
-
1039
- auto name = std::string(__func__) + "_" + suffix;
1040
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1041
- if (!komputeManager()->hasAlgorithm(name)) {
1042
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
1043
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
1044
- } else {
1045
- s_algo = komputeManager()->getAlgorithm(name);
1046
- s_algo->setTensors({inA, inB, out});
1047
- s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)});
1048
- s_algo->setPushConstants<PushConstants>({pushConsts});
1049
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1050
- }
1051
- seq.record<kp::OpAlgoDispatch>(s_algo);
1052
- }
1053
-
1054
- template <typename... Args>
1055
- static void ggml_vk_mul_mat_q4_0(Args&&... args) {
1056
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv,
1057
- kp::shader_data::op_mul_mat_q4_0_comp_spv_len);
1058
-
1059
- ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
1060
- }
1061
-
1062
- template <typename... Args>
1063
- static void ggml_vk_mul_mat_q4_1(Args&&... args) {
1064
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv,
1065
- kp::shader_data::op_mul_mat_q4_1_comp_spv_len);
1066
-
1067
- ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
1068
- }
1069
-
1070
- template <typename... Args>
1071
- static void ggml_vk_mul_mat_q8_0(Args&&... args) {
1072
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q8_0_comp_spv,
1073
- kp::shader_data::op_mul_mat_q8_0_comp_spv_len);
1074
-
1075
- ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
1076
- }
1077
-
1078
- static void ggml_vk_mul_mat_q6_k(
1079
- kp::Sequence& seq,
1080
- const std::shared_ptr<kp::Tensor>& inA,
1081
- const std::shared_ptr<kp::Tensor>& inB,
1082
- const std::shared_ptr<kp::Tensor>& out,
1083
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1084
- int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
1085
- int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
1086
- ) {
1087
- const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
1088
- kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
1089
-
1090
- struct PushConstants {
1091
- uint32_t inAOff, inBOff, outOff;
1092
- int32_t ne00, ne10, ne0, ne1, ne01, gqa;
1093
- } pushConsts {
1094
- inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
1095
- ne00, ne10, ne0, ne1, ne01, ne12/ne02
1096
- };
1097
-
1098
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1099
- if (!komputeManager()->hasAlgorithm(__func__)) {
1100
- const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
1101
- s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
1102
- } else {
1103
- s_algo = komputeManager()->getAlgorithm(__func__);
1104
- s_algo->setTensors({inA, inB, out});
1105
- s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
1106
- s_algo->setPushConstants<PushConstants>({pushConsts});
1107
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1108
- }
1109
- seq.record<kp::OpAlgoDispatch>(s_algo);
1110
- }
1111
-
1112
- static void ggml_vk_get_rows(
1113
- const std::vector<uint32_t>& spirv,
1114
- const char * suffix,
1115
- unsigned element_size, unsigned qk,
1116
- kp::Sequence& seq,
1117
- const std::shared_ptr<kp::Tensor>& inA,
1118
- const std::shared_ptr<kp::Tensor>& inB,
1119
- const std::shared_ptr<kp::Tensor>& out,
1120
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1121
- int32_t ne00, int32_t nb01, int32_t nb1,
1122
- uint32_t size
1123
- ) {
1124
- GGML_ASSERT(nb01%element_size == 0);
1125
- GGML_ASSERT(nb1%sizeof(float) == 0);
1126
- if (qk) GGML_ASSERT(ne00%qk == 0);
1127
-
1128
- struct PushConstants {
1129
- uint32_t inAOff, inBOff, outOff;
1130
- int32_t ne00, nb01, nb1;
1131
- } pushConsts {
1132
- safe_divide(inAOff, element_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
1133
- ne00, nb01, nb1
1134
- };
1135
-
1136
- auto name = std::string(__func__) + "_" + suffix;
1137
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1138
- if (!komputeManager()->hasAlgorithm(name)) {
1139
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
1140
- } else {
1141
- s_algo = komputeManager()->getAlgorithm(name);
1142
- s_algo->setTensors({inA, inB, out});
1143
- s_algo->setWorkgroup({size});
1144
- s_algo->setPushConstants<PushConstants>({pushConsts});
1145
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1146
- }
1147
- seq.record<kp::OpAlgoDispatch>(s_algo);
1148
- }
1149
-
1150
- template <typename... Args>
1151
- static void ggml_vk_get_rows_f32(Args&&... args) {
1152
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
1153
- kp::shader_data::op_getrows_f32_comp_spv_len);
1154
-
1155
- ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
1156
- }
1157
-
1158
- template <typename... Args>
1159
- static void ggml_vk_get_rows_f16(Args&&... args) {
1160
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
1161
- kp::shader_data::op_getrows_f16_comp_spv_len);
1162
-
1163
- ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
1164
- }
1165
-
1166
- template <typename... Args>
1167
- static void ggml_vk_get_rows_q4_0(Args&&... args) {
1168
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
1169
- kp::shader_data::op_getrows_q4_0_comp_spv_len);
1170
-
1171
- ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
1172
- }
1173
-
1174
- template <typename... Args>
1175
- static void ggml_vk_get_rows_q4_1(Args&&... args) {
1176
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
1177
- kp::shader_data::op_getrows_q4_1_comp_spv_len);
1178
-
1179
- ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
1180
- }
1181
-
1182
- template <typename... Args>
1183
- static void ggml_vk_get_rows_q6_k(Args&&... args) {
1184
- const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
1185
- kp::shader_data::op_getrows_q6_k_comp_spv_len);
1186
- ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
1187
- }
1188
-
1189
- static void ggml_vk_rope(
1190
- kp::Sequence& seq,
1191
- const std::shared_ptr<kp::Tensor>& inA,
1192
- const std::shared_ptr<kp::Tensor>& inB,
1193
- const std::shared_ptr<kp::Tensor>& out,
1194
- uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1195
- ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
1196
- float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
1197
- int32_t ne01, int32_t ne02, int32_t ne03,
1198
- uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
1199
- int32_t ne0,
1200
- uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
1201
- ) {
1202
- GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
1203
-
1204
- static const auto spirv_f16 = getSpirvShader(
1205
- kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
1206
- );
1207
- static const auto spirv_f32 = getSpirvShader(
1208
- kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
1209
- );
1210
-
1211
- int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
1212
-
1213
- GGML_ASSERT(nb03 % type_size == 0);
1214
- GGML_ASSERT(nb02 % type_size == 0);
1215
- GGML_ASSERT(nb01 % type_size == 0);
1216
- GGML_ASSERT(nb00 % type_size == 0);
1217
- GGML_ASSERT(nb3 % type_size == 0);
1218
- GGML_ASSERT(nb2 % type_size == 0);
1219
- GGML_ASSERT(nb1 % type_size == 0);
1220
- GGML_ASSERT(nb0 % type_size == 0);
1221
-
1222
- struct PushConstants {
1223
- uint32_t inAOff, inBOff, outOff;
1224
- int32_t n_dims, mode, n_ctx_orig;
1225
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1226
- uint32_t nb00, nb01, nb02, nb03;
1227
- int32_t ne0;
1228
- uint32_t nb0, nb1, nb2, nb3;
1229
- } pushConsts {
1230
- safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
1231
- n_dims, mode, n_ctx_orig,
1232
- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1233
- nb00, nb01, nb02, nb03,
1234
- ne0,
1235
- nb0, nb1, nb2, nb3
1236
- };
1237
-
1238
- auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
1239
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1240
- if (!komputeManager()->hasAlgorithm(name)) {
1241
- s_algo = komputeManager()->algorithm<float, PushConstants>(
1242
- name, s_kompute_context->pool.get(), {inA, inB, out},
1243
- src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
1244
- {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
1245
- );
1246
- } else {
1247
- s_algo = komputeManager()->getAlgorithm(name);
1248
- s_algo->setTensors({inA, inB, out});
1249
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
1250
- s_algo->setPushConstants<PushConstants>({pushConsts});
1251
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1252
- }
1253
- seq.record<kp::OpAlgoDispatch>(s_algo);
1254
- }
1255
-
1256
- static void ggml_vk_cpy(
1257
- const std::vector<uint32_t>& spirv,
1258
- uint32_t in_element_size, uint32_t out_element_size,
1259
- kp::Sequence& seq,
1260
- const std::shared_ptr<kp::Tensor>& in,
1261
- const std::shared_ptr<kp::Tensor>& out,
1262
- uint32_t inOff, uint32_t outOff,
1263
- int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
1264
- uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
1265
- int32_t ne0, int32_t ne1, int32_t ne2,
1266
- uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
1267
- ) {
1268
- struct PushConstants {
1269
- uint32_t inOff, outOff;
1270
- int32_t ne00, ne01, ne02;
1271
- uint32_t nb00, nb01, nb02, nb03;
1272
- int32_t ne0, ne1, ne2;
1273
- uint32_t nb0, nb1, nb2, nb3;
1274
- } pushConsts {
1275
- safe_divide(inOff, in_element_size), safe_divide(outOff, out_element_size),
1276
- ne00, ne01, ne02,
1277
- nb00, nb01, nb02, nb03,
1278
- ne0, ne1, ne2,
1279
- nb0, nb1, nb2, nb3
1280
- };
1281
-
1282
- std::string name = std::string(__func__)
1283
- + "_i_" + std::to_string(in_element_size)
1284
- + "_o_" + std::to_string(out_element_size);
1285
- std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1286
- if (!komputeManager()->hasAlgorithm(name))
1287
- s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
1288
- else {
1289
- s_algo = komputeManager()->getAlgorithm(name);
1290
- s_algo->setTensors({in, out});
1291
- s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
1292
- s_algo->setPushConstants<PushConstants>({pushConsts});
1293
- s_algo->updateDescriptors(s_kompute_context->pool.get());
1294
- }
1295
- seq.record<kp::OpAlgoDispatch>(s_algo);
1296
- }
1297
-
1298
- template <typename... Args>
1299
- static void ggml_vk_cpy_f32_f16(Args&&... args) {
1300
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv,
1301
- kp::shader_data::op_cpy_f32_f16_comp_spv_len);
1302
- ggml_vk_cpy(spirv, 4, 2, std::forward<Args>(args)...);
1303
- }
1304
-
1305
- template <typename... Args>
1306
- static void ggml_vk_cpy_f32_f32(Args&&... args) {
1307
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv,
1308
- kp::shader_data::op_cpy_f32_f32_comp_spv_len);
1309
- ggml_vk_cpy(spirv, 4, 4, std::forward<Args>(args)...);
1310
- }
1311
-
1312
- template <typename... Args>
1313
- static void ggml_vk_cpy_f16_f16(Args&&... args) {
1314
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv,
1315
- kp::shader_data::op_cpy_f16_f16_comp_spv_len);
1316
- ggml_vk_cpy(spirv, 2, 2, std::forward<Args>(args)...);
1317
- }
1318
-
1319
- template <typename... Args>
1320
- static void ggml_vk_cpy_f16_f32(Args&&... args) {
1321
- const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv,
1322
- kp::shader_data::op_cpy_f16_f32_comp_spv_len);
1323
- ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
1324
- }
1325
-
1326
- static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
1327
- switch (op->type) {
1328
- case GGML_TYPE_F16:
1329
- case GGML_TYPE_F32:
1330
- case GGML_TYPE_Q4_0:
1331
- case GGML_TYPE_Q4_1:
1332
- break;
1333
- default:
1334
- return false;
1335
- }
1336
-
1337
- switch (op->op) {
1338
- case GGML_OP_UNARY:
1339
- switch (ggml_get_unary_op(op)) {
1340
- case GGML_UNARY_OP_RELU:
1341
- case GGML_UNARY_OP_GELU:
1342
- case GGML_UNARY_OP_SILU:
1343
- return ggml_is_contiguous(op->src[0]);
1344
- default:
1345
- ;
1346
- }
1347
- break;
1348
- case GGML_OP_NONE:
1349
- case GGML_OP_RESHAPE:
1350
- case GGML_OP_VIEW:
1351
- case GGML_OP_TRANSPOSE:
1352
- case GGML_OP_PERMUTE:
1353
- case GGML_OP_ADD:
1354
- case GGML_OP_MUL:
1355
- case GGML_OP_SCALE:
1356
- case GGML_OP_SOFT_MAX:
1357
- case GGML_OP_RMS_NORM:
1358
- case GGML_OP_NORM:
1359
- case GGML_OP_ROPE:
1360
- return true;
1361
- case GGML_OP_DUP:
1362
- case GGML_OP_CPY:
1363
- case GGML_OP_CONT:
1364
- switch (op->src[0]->type) {
1365
- case GGML_TYPE_F32:
1366
- case GGML_TYPE_F16:
1367
- break;
1368
- default:
1369
- return false;
1370
- }
1371
- switch (op->type) {
1372
- case GGML_TYPE_F32:
1373
- case GGML_TYPE_F16:
1374
- break;
1375
- default:
1376
- return false;
1377
- }
1378
- return true;
1379
- case GGML_OP_DIAG_MASK_INF:
1380
- return op->ne[3] == 1;
1381
- case GGML_OP_GET_ROWS:
1382
- switch (op->src[0]->type) {
1383
- case GGML_TYPE_F32:
1384
- case GGML_TYPE_F16:
1385
- case GGML_TYPE_Q4_0:
1386
- case GGML_TYPE_Q4_1:
1387
- case GGML_TYPE_Q6_K:
1388
- return op->ne[2] == 1 && op->ne[3] == 1;
1389
- default:
1390
- ;
1391
- }
1392
- return false;
1393
- case GGML_OP_MUL_MAT:
1394
- if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1]))
1395
- return false;
1396
-
1397
- switch (op->src[0]->type) {
1398
- case GGML_TYPE_F32:
1399
- case GGML_TYPE_Q6_K:
1400
- return op->ne[3] == 1;
1401
- case GGML_TYPE_F16:
1402
- case GGML_TYPE_Q8_0:
1403
- case GGML_TYPE_Q4_0:
1404
- case GGML_TYPE_Q4_1:
1405
- return true;
1406
- default:
1407
- ;
1408
- }
1409
- default:
1410
- ;
1411
- }
1412
- return false;
1413
- }
1414
-
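ggml_vk_supports_op() filters first on the destination tensor type and then on the operation, returning false for anything the Kompute shaders do not cover; the scheduler consults this predicate before assigning a node to the backend. A minimal sketch of that two-stage gating, with hypothetical enums rather than the real ggml types:

```cpp
#include <cstdio>

enum class DType { F32, F16, Q4_0, Q8_0 };
enum class Op    { Add, MulMat, Rope, FlashAttn };

// Two-stage whitelist, mirroring the type-then-op structure above.
static bool backend_supports(DType t, Op op) {
    switch (t) {
        case DType::F32: case DType::F16: case DType::Q4_0: break;
        default: return false;                 // unsupported tensor type
    }
    switch (op) {
        case Op::Add: case Op::MulMat: case Op::Rope: return true;
        default: return false;                 // unsupported operation
    }
}

int main() {
    printf("%d\n", backend_supports(DType::F16, Op::MulMat));    // 1
    printf("%d\n", backend_supports(DType::Q8_0, Op::Add));      // 0: type filtered out
    printf("%d\n", backend_supports(DType::F32, Op::FlashAttn)); // 0: op filtered out
}
```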
1415
- static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
1416
- const int n_seq = 8;
1417
-
1418
- // FIXME: Figure out if we can somehow optimize the size of the pool... right now we're setting
1419
- // it to the size of the graph, but I think it can be made smaller?
1420
- ggml_vk_allocate_descriptor_pool(ctx, gf->n_nodes);
1421
-
1422
- std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
1423
-
1424
- for (auto& sequence : sequences) {
1425
- sequence = komputeManager()->sequence();
1426
- }
1427
- for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) {
1428
- const int n_nodes_per_seq = (gf->n_nodes + n_seq - 1) / n_seq;
1429
-
1430
- auto& seq = *sequences[seq_idx];
1431
-
1432
- const int node_start = (seq_idx + 0) * n_nodes_per_seq;
1433
- const int node_end = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes);
1434
-
1435
- bool any_commands_recorded = false;
1436
-
1437
- for (int i = node_start; i < node_end; ++i) {
1438
- struct ggml_tensor * src0 = gf->nodes[i]->src[0];
1439
- struct ggml_tensor * src1 = gf->nodes[i]->src[1];
1440
- struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
1441
- struct ggml_tensor * dst = gf->nodes[i];
1442
- GGML_ASSERT(dst->data != nullptr);
1443
-
1444
- if (ggml_is_empty(dst)) {
1445
- continue;
1446
- }
1447
-
1448
- switch (dst->op) {
1449
- case GGML_OP_NONE:
1450
- case GGML_OP_RESHAPE:
1451
- case GGML_OP_VIEW:
1452
- case GGML_OP_TRANSPOSE:
1453
- case GGML_OP_PERMUTE:
1454
- continue; // noop -> next node
1455
- default:
1456
- break;
1457
- }
1458
-
1459
- any_commands_recorded = true;
1460
-
1461
- if (!ggml_vk_supports_op(dst)) {
1462
- fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
1463
- GGML_ASSERT(!"unsupported op");
1464
- }
1465
-
1466
- const int32_t ne00 = src0 ? src0->ne[0] : 0;
1467
- const int32_t ne01 = src0 ? src0->ne[1] : 0;
1468
- const int32_t ne02 = src0 ? src0->ne[2] : 0;
1469
- const int32_t ne03 = src0 ? src0->ne[3] : 0;
1470
-
1471
- const uint32_t nb00 = src0 ? src0->nb[0] : 0;
1472
- const uint32_t nb01 = src0 ? src0->nb[1] : 0;
1473
- const uint32_t nb02 = src0 ? src0->nb[2] : 0;
1474
- const uint32_t nb03 = src0 ? src0->nb[3] : 0;
1475
-
1476
- const int32_t ne10 = src1 ? src1->ne[0] : 0;
1477
- const int32_t ne11 = src1 ? src1->ne[1] : 0;
1478
- const int32_t ne12 = src1 ? src1->ne[2] : 0;
1479
- const int32_t ne13 = src1 ? src1->ne[3] : 0;
1480
-
1481
- const uint32_t nb10 = src1 ? src1->nb[0] : 0;
1482
- const uint32_t nb11 = src1 ? src1->nb[1] : 0;
1483
- const uint32_t nb12 = src1 ? src1->nb[2] : 0;
1484
- const uint32_t nb13 = src1 ? src1->nb[3] : 0;
1485
-
1486
- const int32_t ne0 = dst ? dst->ne[0] : 0;
1487
- const int32_t ne1 = dst ? dst->ne[1] : 0;
1488
- const int32_t ne2 = dst ? dst->ne[2] : 0;
1489
- // const int32_t ne3 = dst ? dst->ne[3] : 0;
1490
-
1491
- const uint32_t nb0 = dst ? dst->nb[0] : 0;
1492
- const uint32_t nb1 = dst ? dst->nb[1] : 0;
1493
- const uint32_t nb2 = dst ? dst->nb[2] : 0;
1494
- const uint32_t nb3 = dst ? dst->nb[3] : 0;
1495
-
1496
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
1497
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
1498
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
1499
-
1500
- const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
1501
- uint32_t off_src0 = 0;
1502
- uint32_t off_src1 = 0;
1503
- uint32_t off_dst = 0;
1504
- const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
1505
- const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
1506
- const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;
1507
-
1508
- switch (dst->op) {
1509
- case GGML_OP_ADD:
1510
- {
1511
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
1512
- // src1 is a row
1513
- ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
1514
- } else {
1515
- ggml_vk_add(
1516
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1517
- ne00, ne01, ne02, ne03,
1518
- nb00, nb01, nb02, nb03,
1519
- ne10, ne11, ne12, ne13,
1520
- nb10, nb11, nb12, nb13,
1521
- ne0,
1522
- nb0, nb1, nb2, nb3
1523
- );
1524
- }
1525
- } break;
1526
- case GGML_OP_MUL:
1527
- {
1528
- ggml_vk_mul(
1529
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1530
- ne00, ne01, ne02, ne03,
1531
- nb00, nb01, nb02, nb03,
1532
- ne10, ne11, ne12, ne13,
1533
- nb10, nb11, nb12, nb13,
1534
- ne0,
1535
- nb0, nb1, nb2, nb3
1536
- );
1537
- } break;
1538
- case GGML_OP_SCALE:
1539
- {
1540
- float scale; memcpy(&scale, dst->op_params, sizeof(float));
1541
-
1542
- ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
1543
- } break;
1544
- case GGML_OP_UNARY:
1545
- {
1546
- int64_t n = ggml_nelements(dst);
1547
- GGML_ASSERT(n % 4 == 0);
1548
- switch (ggml_get_unary_op(gf->nodes[i])) {
1549
- case GGML_UNARY_OP_SILU:
1550
- {
1551
- ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
1552
- } break;
1553
- case GGML_UNARY_OP_RELU:
1554
- {
1555
- ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
1556
- } break;
1557
- case GGML_UNARY_OP_GELU:
1558
- {
1559
- GGML_ASSERT(n % 8 == 0);
1560
- ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
1561
- } break;
1562
- default:
1563
- {
1564
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1565
- GGML_ASSERT(false);
1566
- }
1567
- }
1568
- } break;
1569
- case GGML_OP_SOFT_MAX:
1570
- {
1571
- float scale;
1572
- float max_bias;
1573
-
1574
- memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
1575
- memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
1576
-
1577
- #pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
1578
- #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
1579
- GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
1580
-
1581
- #pragma message("TODO: add ALiBi support")
1582
- #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
1583
- GGML_ASSERT(max_bias == 0.0f);
1584
-
1585
- ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
1586
- } break;
1587
- case GGML_OP_DIAG_MASK_INF:
1588
- {
1589
- const int n_past = ((int32_t *)(dst->op_params))[0];
1590
- ggml_vk_diag_mask_inf(seq, id_src0, id_dst, off_src0, off_dst, n_past, ne00, ne01, ne02);
1591
- } break;
1592
- case GGML_OP_NORM:
1593
- {
1594
- float eps;
1595
- memcpy(&eps, dst->op_params, sizeof(float));
1596
- ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
1597
- } break;
1598
- case GGML_OP_RMS_NORM:
1599
- {
1600
- GGML_ASSERT(ne00 % 4 == 0);
1601
-
1602
- float eps;
1603
- memcpy(&eps, dst->op_params, sizeof(float));
1604
- ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
1605
- } break;
1606
- case GGML_OP_MUL_MAT:
1607
- {
1608
- GGML_ASSERT(ne00 == ne10);
1609
-
1610
- GGML_ASSERT(ne12 % ne02 == 0);
1611
- GGML_ASSERT(ne13 % ne03 == 0);
1612
-
1613
- const uint32_t r2 = ne12/ne02;
1614
- const uint32_t r3 = ne13/ne03;
1615
-
1616
- if (src1t != GGML_TYPE_F32) {
1617
- fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
1618
- goto not_implemented;
1619
- }
1620
-
1621
- if (ggml_is_transposed(src0) ||
1622
- ggml_is_transposed(src1)) {
1623
- fprintf(stderr, "%s: %s: matmul on transposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
1624
- goto not_implemented;
1625
- }
1626
-
1627
- switch (src0t) {
1628
- case GGML_TYPE_F32:
1629
- ggml_vk_mul_mat_mat_f32(
1630
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1631
- ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2
1632
- );
1633
- break;
1634
- case GGML_TYPE_F16:
1635
- ggml_vk_mul_mat_f16(
1636
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1637
- ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
1638
- ne0, ne1, r2, r3
1639
- );
1640
- break;
1641
- case GGML_TYPE_Q8_0:
1642
- ggml_vk_mul_mat_q8_0(
1643
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1644
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1645
- );
1646
- break;
1647
- case GGML_TYPE_Q4_0:
1648
- ggml_vk_mul_mat_q4_0(
1649
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1650
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1651
- );
1652
- break;
1653
- case GGML_TYPE_Q4_1:
1654
- ggml_vk_mul_mat_q4_1(
1655
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1656
- ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
1657
- );
1658
- break;
1659
- case GGML_TYPE_Q6_K:
1660
- ggml_vk_mul_mat_q6_k(
1661
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
1662
- ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
1663
- );
1664
- break;
1665
- default: {
1666
- fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
1667
- goto not_implemented;
1668
- }
1669
- }
1670
-
1671
- } break;
1672
- case GGML_OP_GET_ROWS:
1673
- {
1674
- if (src0t == GGML_TYPE_F32) {
1675
- ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1676
- } else if (src0t == GGML_TYPE_F16) {
1677
- ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1678
- } else if (src0t == GGML_TYPE_Q4_0) {
1679
- ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1680
- } else if (src0t == GGML_TYPE_Q4_1) {
1681
- ggml_vk_get_rows_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1682
- } else if (src0t == GGML_TYPE_Q6_K) {
1683
- ggml_vk_get_rows_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1684
- } else {
1685
- fprintf(stderr, "%s: %s: Unsupported quantization: %u\n", __func__, ggml_op_name(dst->op), src0t);
1686
- goto not_implemented;
1687
- }
1688
- } break;
1689
- case GGML_OP_ROPE:
1690
- {
1691
- #pragma message("TODO: implement phi3 frequency factors support")
1692
- #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
1693
- GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
1694
-
1695
- #pragma message("TODO: update rope NORM mode to match NEOX mode")
1696
- #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
1697
-
1698
- GGML_ASSERT(ne10 == ne02);
1699
- GGML_ASSERT(src0t == dstt);
1700
- // const int n_past = ((int32_t *) dst->op_params)[0];
1701
- const int n_dims = ((int32_t *) dst->op_params)[1];
1702
- const int mode = ((int32_t *) dst->op_params)[2];
1703
- // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
1704
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
1705
-
1706
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1707
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1708
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1709
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1710
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1711
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1712
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1713
- ggml_vk_rope(
1714
- seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
1715
- freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1716
- ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
1717
- );
1718
- } break;
1719
- case GGML_OP_DUP:
1720
- case GGML_OP_CPY:
1721
- case GGML_OP_CONT:
1722
- {
1723
- switch (src0t) {
1724
- case GGML_TYPE_F32:
1725
- {
1726
- switch (dstt) {
1727
- case GGML_TYPE_F16: ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
1728
- case GGML_TYPE_F32: ggml_vk_cpy_f32_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
1729
- default: goto not_implemented;
1730
- }
1731
- } break;
1732
- case GGML_TYPE_F16:
1733
- {
1734
- switch (dstt) {
1735
- case GGML_TYPE_F16: ggml_vk_cpy_f16_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
1736
- case GGML_TYPE_F32: ggml_vk_cpy_f16_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
1737
- default: goto not_implemented;
1738
- } break;
1739
- default: goto not_implemented;
1740
- }
1741
- }
1742
- } break;
1743
- default: goto not_implemented;
1744
- }
1745
- continue;
1746
- not_implemented: {}
1747
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1748
- //GGML_ASSERT(false);
1749
- }
1750
-
1751
- // Evaluate sequence
1752
- if (any_commands_recorded) {
1753
- seq.evalAsync();
1754
- }
1755
- }
1756
-
1757
- // Wait for all sequences to finish
1758
- for (auto& sequence : sequences) {
1759
- if (sequence->isRunning())
1760
- sequence->evalAwait();
1761
- }
1762
-
1763
- ggml_vk_free_descriptor_pool(ctx);
1764
- }
1765
-
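ggml_vk_graph_compute() above splits the graph's nodes evenly across n_seq command sequences, records each slice, evaluates the sequences asynchronously, and then waits on all of them. The slicing arithmetic in isolation (same formula as the removed code, standalone names):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    const int n_nodes = 30;
    const int n_seq   = 8;
    // Round up so the earlier sequences absorb the remainder.
    const int n_nodes_per_seq = (n_nodes + n_seq - 1) / n_seq;   // 4

    for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) {
        const int node_start = seq_idx * n_nodes_per_seq;
        const int node_end   = std::min((seq_idx == n_seq - 1) ? n_nodes
                                                               : (seq_idx + 1) * n_nodes_per_seq,
                                        n_nodes);
        if (node_start >= node_end) continue;   // trailing sequences may end up empty
        printf("sequence %d records nodes [%d, %d)\n", seq_idx, node_start, node_end);
    }
}
```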
1766
- template<>
1767
- kp::Tensor::TensorDataTypes
1768
- kp::TensorT<half>::dataType()
1769
- {
1770
- return TensorDataTypes::eFloat;
1771
- }
1772
-
1773
- template<>
1774
- kp::Tensor::TensorDataTypes
1775
- kp::TensorT<uint8_t>::dataType()
1776
- {
1777
- return TensorDataTypes::eUnsignedInt;
1778
- }
1779
-
1780
- ////////////////////////////////////////////////////////////////////////////////
1781
-
1782
- // backend interface
1783
-
1784
- struct ggml_backend_kompute_buffer_type_context {
1785
- int device;
1786
- int device_ref = 0;
1787
- uint64_t buffer_alignment;
1788
- uint64_t max_alloc;
1789
- std::string name;
1790
-
1791
- ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc)
1792
- : device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {}
1793
- };
1794
-
1795
- static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) {
1796
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
1797
-
1798
- if (!ctx->device_ref) {
1799
- komputeManager()->initializeDevice(
1800
- ctx->device, {}, {
1801
- "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage",
1802
- "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info"
1803
- }
1804
- );
1805
- }
1806
-
1807
- assert(ggml_vk_has_device());
1808
- ctx->device_ref++;
1809
- }
1810
-
1811
- static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
1812
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
1813
-
1814
- assert(ctx->device_ref > 0);
1815
-
1816
- ctx->device_ref--;
1817
-
1818
- if (!ctx->device_ref) {
1819
- komputeManager.destroy();
1820
- }
1821
- }
1822
-
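Device lifetime above is managed with a plain reference count on the buffer-type context: the Vulkan device is initialized on the first reference and the Kompute manager is destroyed when the count drops back to zero. A generic sketch of that idiom with placeholder names:

```cpp
#include <cassert>
#include <cstdio>

struct device_ctx {
    int ref_count = 0;
};

static void device_init()    { printf("initialize device\n"); }
static void device_destroy() { printf("destroy device\n"); }

static void device_ref(device_ctx & ctx) {
    if (ctx.ref_count == 0) device_init();       // first user brings the device up
    ctx.ref_count++;
}

static void device_unref(device_ctx & ctx) {
    assert(ctx.ref_count > 0);
    if (--ctx.ref_count == 0) device_destroy();  // last user tears it down
}

int main() {
    device_ctx ctx;
    device_ref(ctx);
    device_ref(ctx);
    device_unref(ctx);
    device_unref(ctx);   // "destroy device" is printed here
}
```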
1823
- static const char * ggml_backend_kompute_buffer_get_name(ggml_backend_buffer_t buffer) {
1824
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buffer->buft->context);
1825
- return ctx->name.c_str();
1826
- }
1827
-
1828
- static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1829
- auto * memory = (ggml_vk_memory *)buffer->context;
1830
- if (ggml_vk_has_device()) {
1831
- ggml_vk_free_memory(*memory);
1832
- }
1833
- delete memory;
1834
- }
1835
-
1836
- static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
1837
- return ((ggml_vk_memory *)buffer->context)->data;
1838
- }
1839
-
1840
- static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1841
- GGML_UNUSED(buffer);
1842
-
1843
- const auto res = ggml_vk_get_tensor(tensor);
1844
- GGML_ASSERT(res);
1845
-
1846
- memcpy((char *)tensor->data + offset, data, size);
1847
-
1848
- komputeManager()->sequence()->eval<kp::OpTensorSyncDevice>({res});
1849
- }
1850
-
1851
- static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1852
- GGML_UNUSED(buffer);
1853
-
1854
- const auto res = ggml_vk_get_tensor(tensor);
1855
- GGML_ASSERT(res);
1856
-
1857
- komputeManager()->sequence()->eval<kp::OpTensorSyncLocal>({res});
1858
-
1859
- memcpy(data, (const char *)tensor->data + offset, size);
1860
- }
1861
-
1862
- static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1863
- auto * memory = (ggml_vk_memory *)buffer->context;
1864
- memset(memory->data, value, buffer->size);
1865
-
1866
- if (memory->stagingBuffer)
1867
- komputeManager()->sequence()->eval<kp::OpBufferSyncDevice>(memory->primaryBuffer, memory->stagingBuffer, memory->size);
1868
- }
1869
-
1870
- static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
1871
- /* .get_name = */ ggml_backend_kompute_buffer_get_name,
1872
- /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer,
1873
- /* .get_base = */ ggml_backend_kompute_buffer_get_base,
1874
- /* .init_tensor = */ NULL,
1875
- /* .set_tensor = */ ggml_backend_kompute_buffer_set_tensor,
1876
- /* .get_tensor = */ ggml_backend_kompute_buffer_get_tensor,
1877
- /* .cpy_tensor = */ NULL,
1878
- /* .clear = */ ggml_backend_kompute_buffer_clear,
1879
- /* .reset = */ NULL,
1880
- };
1881
-
1882
- // default buffer type
1883
-
1884
- static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
1885
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
1886
- return ctx->name.c_str();
1887
- }
1888
-
1889
- static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1890
- ggml_backend_kompute_device_ref(buft);
1891
- auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size));
1892
- return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size);
1893
- }
1894
-
1895
- static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1896
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
1897
- return ctx->buffer_alignment;
1898
- }
1899
-
1900
- static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
1901
- auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
1902
- return ctx->max_alloc;
1903
- }
1904
-
1905
- static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
1906
- /* .get_name = */ ggml_backend_kompute_buffer_type_get_name,
1907
- /* .alloc_buffer = */ ggml_backend_kompute_buffer_type_alloc_buffer,
1908
- /* .get_alignment = */ ggml_backend_kompute_buffer_type_get_alignment,
1909
- /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
1910
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1911
- /* .is_host = */ NULL,
1912
- };
1913
-
1914
- ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
1915
- static std::vector<ggml_backend_buffer_type> bufts = []() {
1916
- std::vector<ggml_backend_buffer_type> vec;
1917
- auto devices = ggml_vk_available_devices_internal(0);
1918
- vec.reserve(devices.size());
1919
-
1920
- for (const auto & dev : devices) {
1921
- vec.push_back({
1922
- /* .iface = */ ggml_backend_kompute_buffer_type_interface,
1923
- /* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
1924
- });
1925
- }
1926
- return vec;
1927
- }();
1928
-
1929
- auto it = std::find_if(bufts.begin(), bufts.end(), [device](const ggml_backend_buffer_type & t) {
1930
- return device == static_cast<ggml_backend_kompute_buffer_type_context *>(t.context)->device;
1931
- });
1932
- return it < bufts.end() ? &*it : nullptr;
1933
- }
1934
-
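ggml_backend_kompute_buffer_type() builds its per-device table lazily with a function-local static initialized by an immediately-invoked lambda, then resolves the requested device with std::find_if. The idiom in isolation, with hypothetical entries instead of real devices:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

struct entry { int device; const char * name; };

static const entry * lookup(int device) {
    // Built exactly once, on first call, with thread-safe static initialization.
    static std::vector<entry> table = []() {
        std::vector<entry> vec;
        vec.push_back({0, "Vulkan0"});
        vec.push_back({1, "Vulkan1"});
        return vec;
    }();

    auto it = std::find_if(table.begin(), table.end(),
                           [device](const entry & e) { return e.device == device; });
    return it != table.end() ? &*it : nullptr;
}

int main() {
    if (const entry * e = lookup(1)) printf("found %s\n", e->name);
    if (!lookup(7))                  printf("device 7 not registered\n");
}
```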
1935
- // backend
1936
-
1937
- static const char * ggml_backend_kompute_name(ggml_backend_t backend) {
1938
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1939
- return ctx->name.c_str();
1940
- }
1941
-
1942
- static void ggml_backend_kompute_free(ggml_backend_t backend) {
1943
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1944
-
1945
- assert(ctx == s_kompute_context);
1946
- s_kompute_context = nullptr;
1947
- if (ctx != nullptr) {
1948
- delete ctx;
1949
- }
1950
-
1951
- delete backend;
1952
- }
1953
-
1954
- static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(ggml_backend_t backend) {
1955
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1956
- return ggml_backend_kompute_buffer_type(ctx->device);
1957
- }
1958
-
1959
- static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1960
- auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1961
- ggml_vk_graph_compute(ctx, cgraph);
1962
- return GGML_STATUS_SUCCESS;
1963
- }
1964
-
1965
- static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
1966
- GGML_UNUSED(backend);
1967
- return ggml_vk_supports_op(op);
1968
- }
1969
-
1970
- static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
1971
- GGML_UNUSED(backend);
1972
- return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
1973
- }
1974
-
1975
- static struct ggml_backend_i kompute_backend_i = {
1976
- /* .get_name = */ ggml_backend_kompute_name,
1977
- /* .free = */ ggml_backend_kompute_free,
1978
- /* .get_default_buffer_type = */ ggml_backend_kompute_get_default_buffer_type,
1979
- /* .set_tensor_async = */ NULL,
1980
- /* .get_tensor_async = */ NULL,
1981
- /* .cpy_tensor_async = */ NULL,
1982
- /* .synchronize = */ NULL,
1983
- /* .graph_plan_create = */ NULL,
1984
- /* .graph_plan_free = */ NULL,
1985
- /* .graph_plan_update = */ NULL,
1986
- /* .graph_plan_compute = */ NULL,
1987
- /* .graph_compute = */ ggml_backend_kompute_graph_compute,
1988
- /* .supports_op = */ ggml_backend_kompute_supports_op,
1989
- /* .supports_buft = */ ggml_backend_kompute_supports_buft,
1990
- /* .offload_op = */ NULL,
1991
- /* .event_new = */ NULL,
1992
- /* .event_free = */ NULL,
1993
- /* .event_record = */ NULL,
1994
- /* .event_wait = */ NULL,
1995
- /* .event_synchronize = */ NULL,
1996
- };
1997
-
1998
- static ggml_guid_t ggml_backend_kompute_guid() {
1999
- static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
2000
- return &guid;
2001
- }
2002
-
2003
- ggml_backend_t ggml_backend_kompute_init(int device) {
2004
- GGML_ASSERT(s_kompute_context == nullptr);
2005
- s_kompute_context = new ggml_kompute_context(device);
2006
-
2007
- ggml_backend_t kompute_backend = new ggml_backend {
2008
- /* .guid = */ ggml_backend_kompute_guid(),
2009
- /* .interface = */ kompute_backend_i,
2010
- /* .context = */ s_kompute_context,
2011
- };
2012
-
2013
- return kompute_backend;
2014
- }
2015
-
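For context on the entry points removed here: a caller constructs the backend with ggml_backend_kompute_init() and releases it through the generic ggml-backend API. A minimal usage sketch; the headers and the surrounding ggml_backend_* calls come from the vendored tree, but treat this as an illustrative, untested snippet, and assume device 0 is a usable Vulkan device:

```cpp
#include <cstdio>

#include "ggml-backend.h"
#include "ggml-kompute.h"

int main() {
    // Initialize the Kompute backend on device index 0.
    ggml_backend_t backend = ggml_backend_kompute_init(0);
    if (backend == nullptr) {
        fprintf(stderr, "failed to initialize the Kompute backend\n");
        return 1;
    }

    printf("backend: %s, is_kompute: %d\n",
           ggml_backend_name(backend), ggml_backend_is_kompute(backend));

    // Graph construction and ggml_backend_graph_compute(backend, graph) would go here.

    ggml_backend_free(backend);
    return 0;
}
```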
2016
- bool ggml_backend_is_kompute(ggml_backend_t backend) {
2017
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
2018
- }
2019
-
2020
- static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
2021
- GGML_UNUSED(params);
2022
- return ggml_backend_kompute_init(intptr_t(user_data));
2023
- }
2024
-
2025
- extern "C" int ggml_backend_kompute_reg_devices();
2026
-
2027
- int ggml_backend_kompute_reg_devices() {
2028
- auto devices = ggml_vk_available_devices_internal(0);
2029
- for (const auto & device : devices) {
2030
- ggml_backend_register(
2031
- ggml_kompute_format_name(device.index).c_str(),
2032
- ggml_backend_reg_kompute_init,
2033
- ggml_backend_kompute_buffer_type(device.index),
2034
- reinterpret_cast<void *>(intptr_t(device.index))
2035
- );
2036
- }
2037
- return devices.size();
2038
- }