llama_cpp 0.15.4 → 0.16.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (147)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +15 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +13 -1
  7. data/vendor/tmp/llama.cpp/Makefile +62 -35
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +8 -6
  131. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -6
  132. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  133. data/vendor/tmp/llama.cpp/ggml-metal.m +34 -24
  134. data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
  135. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +2 -2
  136. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +7 -67
  137. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  138. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +456 -329
  139. data/vendor/tmp/llama.cpp/ggml.c +178 -330
  140. data/vendor/tmp/llama.cpp/ggml.h +9 -28
  141. data/vendor/tmp/llama.cpp/llama.cpp +242 -426
  142. data/vendor/tmp/llama.cpp/llama.h +17 -43
  143. metadata +121 -6
  144. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  145. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  146. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  147. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec... (see list above) / fattn-wmma-f16-instance-kqfloat-cpb16.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate-variants.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 16, float);
+DECL_FATTN_WMMA_F16_CASE(80, 16, float);
+DECL_FATTN_WMMA_F16_CASE(96, 16, float);
+DECL_FATTN_WMMA_F16_CASE(112, 16, float);
+DECL_FATTN_WMMA_F16_CASE(128, 16, float);
+DECL_FATTN_WMMA_F16_CASE(256, 16, float);
data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
@@ -0,0 +1,9 @@
+// This file has been autogenerated by generate-variants.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 32, float);
+DECL_FATTN_WMMA_F16_CASE(80, 32, float);
+DECL_FATTN_WMMA_F16_CASE(96, 32, float);
+DECL_FATTN_WMMA_F16_CASE(112, 32, float);
+DECL_FATTN_WMMA_F16_CASE(128, 32, float);
data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate-variants.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 16, half);
+DECL_FATTN_WMMA_F16_CASE(80, 16, half);
+DECL_FATTN_WMMA_F16_CASE(96, 16, half);
+DECL_FATTN_WMMA_F16_CASE(112, 16, half);
+DECL_FATTN_WMMA_F16_CASE(128, 16, half);
+DECL_FATTN_WMMA_F16_CASE(256, 16, half);
data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate-variants.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 32, half);
+DECL_FATTN_WMMA_F16_CASE(80, 32, half);
+DECL_FATTN_WMMA_F16_CASE(96, 32, half);
+DECL_FATTN_WMMA_F16_CASE(112, 32, half);
+DECL_FATTN_WMMA_F16_CASE(128, 32, half);
+DECL_FATTN_WMMA_F16_CASE(256, 32, half);
data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
@@ -0,0 +1,8 @@
+// This file has been autogenerated by generate-variants.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 8, half);
+DECL_FATTN_WMMA_F16_CASE(96, 8, half);
+DECL_FATTN_WMMA_F16_CASE(128, 8, half);
+DECL_FATTN_WMMA_F16_CASE(256, 8, half);
data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu
@@ -0,0 +1,47 @@
+#include "tsembd.cuh"
+
+static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
+    // blockIDx.y: idx of timesteps->ne[0]
+    // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
+    int i = blockIdx.y;
+    int j = threadIdx.x + blockIdx.x * blockDim.x;
+    float * embed_data = (float *)((char *)dst + i*nb1);
+
+    if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+        embed_data[dim] = 0.f;
+    }
+
+    int half = dim / 2;
+    if (j >= half) {
+        return;
+    }
+
+    float timestep = timesteps[i];
+    float freq = (float)expf(-logf(max_period) * j / half);
+    float arg = timestep * freq;
+    embed_data[j] = cosf(arg);
+    embed_data[j + half] = sinf(arg);
+}
+
+static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
+                                        const int dim, const int max_period, cudaStream_t stream) {
+    int half_ceil = (dim + 1) / 2;
+    int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne00, 1);
+    timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
+}
+
+void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int dim = dst->op_params[0];
+    const int max_period = dst->op_params[1];
+
+    timestep_embedding_f32_cuda(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
+}
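
Note on the hunk above: timestep_embedding_f32 is the standard sinusoidal timestep embedding used by diffusion models. A minimal host-side sketch of the same math, handy for cross-checking the kernel (my own illustration, assuming an even dim; not part of the gem):

#include <cmath>
#include <cstdio>
#include <vector>

// f_j = exp(-ln(max_period) * j / half); embed[j] = cos(t * f_j); embed[j + half] = sin(t * f_j)
static std::vector<float> timestep_embedding_ref(float t, int dim, int max_period) {
    const int half = dim / 2;
    std::vector<float> embed(dim, 0.0f);
    for (int j = 0; j < half; ++j) {
        const float freq = std::exp(-std::log((float)max_period) * j / half);
        embed[j]        = std::cos(t * freq);
        embed[j + half] = std::sin(t * freq);
    }
    return embed;
}

int main() {
    for (float v : timestep_embedding_ref(10.0f, 8, 10000)) {
        std::printf("% .5f ", v);
    }
    std::printf("\n");
}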
data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu
@@ -0,0 +1,266 @@
+#include "unary.cuh"
+
+static __global__ void gelu_f32(const float * x, float * dst, const int k) {
+    const float GELU_COEF_A = 0.044715f;
+    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    float xi = x[i];
+    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+}
+
+static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
+    const float GELU_QUICK_COEF = -1.702f;
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+}
+
+static __global__ void silu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] / (1.0f + expf(-x[i]));
+}
+
+static __global__ void tanh_f32(const float * x, float * dst, int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = tanhf(x[i]);
+}
+
+static __global__ void relu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0);
+}
+
+static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = 1.0f / (1.0f + expf(-x[i]));
+}
+
+static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
+}
+
+static __global__ void sqr_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * x[i];
+}
+
+static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+    silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
+    tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
+    sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
+    hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
+    hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
+}
+
+static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
+    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
+}
+
+void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
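
Note on the hunk above: every kernel in unary.cu is elementwise, so a scalar CPU reference is enough to cross-check outputs. A minimal sketch of two of the formulas (the tanh-approximation GELU and SiLU; my own illustration, not gem code):

#include <cmath>
#include <cstdio>

// Mirrors gelu_f32: 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2)))
static float gelu_ref(float x) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    return 0.5f*x*(1.0f + std::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

// Mirrors silu_f32: x * sigmoid(x)
static float silu_ref(float x) {
    return x / (1.0f + std::exp(-x));
}

int main() {
    const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
    for (float x : xs) {
        std::printf("x=% .2f  gelu=% .6f  silu=% .6f\n", x, gelu_ref(x), silu_ref(x));
    }
}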
data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu
@@ -0,0 +1,51 @@
+#include "upscale.cuh"
+
+static __global__ void upscale_f32(const float * x, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne10, const int ne11, const int ne12, const int ne13,
+        const float sf0, const float sf1, const float sf2, const float sf3) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    if (index >= ne10 * ne11 * ne12 * ne13) {
+        return;
+    }
+
+    int i10 = index % ne10;
+    int i11 = (index / ne10) % ne11;
+    int i12 = (index / (ne10 * ne11)) % ne12;
+    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+
+    int i00 = i10 / sf0;
+    int i01 = i11 / sf1;
+    int i02 = i12 / sf2;
+    int i03 = i13 / sf3;
+
+    dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
+}
+
+static void upscale_f32_cuda(const float * x, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne10, const int ne11, const int ne12, const int ne13,
+        const float sf0, const float sf1, const float sf2, const float sf3,
+        cudaStream_t stream) {
+    int dst_size = ne10 * ne11 * ne12 * ne13;
+    int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+
+    upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
+}
+
+void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const float sf0 = (float)dst->ne[0]/src0->ne[0];
+    const float sf1 = (float)dst->ne[1]/src0->ne[1];
+    const float sf2 = (float)dst->ne[2]/src0->ne[2];
+    const float sf3 = (float)dst->ne[3]/src0->ne[3];
+
+    upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
+}
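
Note on the hunk above: upscale_f32 is nearest-neighbor interpolation; each destination coordinate maps back to a source coordinate by dividing by the per-dimension scale factor and truncating. A 1-D sketch of that index mapping (my own illustration, not gem code):

#include <cstdio>

int main() {
    const int   ne0 = 4;               // source length
    const float sf  = 2.0f;            // scale factor; destination length = ne0 * sf
    const float src[ne0] = {10, 20, 30, 40};
    for (int i1 = 0; i1 < (int)(ne0 * sf); ++i1) {
        const int i0 = (int)(i1 / sf); // truncating division = nearest-neighbor lookup
        std::printf("dst[%d] = src[%d] = %g\n", i1, i0, src[i0]);
    }
    return 0;
}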
data/vendor/tmp/llama.cpp/ggml-cuda.cu
@@ -2702,10 +2702,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
         if (cuda_graph_update_required) {
             // Extract nodes from graph
-            if (cuda_ctx->cuda_graph->num_nodes == 0) {
-                // First call with null argument gets number of nodes in graph
-                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-            }
+            // First call with null argument gets number of nodes in graph
+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
 
             // Subsequent call with non-null argument gets nodes
             cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
             cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
@@ -2905,10 +2903,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
 #else
-            if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
+            if (op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
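
Net effect of the second hunk above: on non-HIP builds the CUDA backend now advertises FLASH_ATTN_EXT only when the K/V tensors are F16, except for head size 128 (all K/V type combinations are covered by the fattn-vec template instances added in this release) and head size 64 with an F16 K. A condensed restatement of the new predicate (my own paraphrase, not gem code):

// Hypothetical condensed form of the non-HIP FLASH_ATTN_EXT support check after this change.
static bool supports_flash_attn_ext(int head_size, bool k_is_f16, bool v_is_f16, bool cc_ge_volta) {
    if (head_size == 128) return true;              // all K/V quantization combinations instantiated
    if (head_size == 64 && k_is_f16) return true;   // F16 K required on the 64 head-size path
    return cc_ge_volta && k_is_f16 && v_is_f16;     // WMMA kernels: Volta+ and F16 K/V only
}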
data/vendor/tmp/llama.cpp/ggml-kompute.cpp
@@ -22,6 +22,7 @@
 #include "shaderop_mul_mat_q4_1.h"
 #include "shaderop_mul_mat_q6_k.h"
 #include "shaderop_mul_mat_mat_f32.h"
+#include "shaderop_getrows_f32.h"
 #include "shaderop_getrows_f16.h"
 #include "shaderop_getrows_q4_0.h"
 #include "shaderop_getrows_q4_1.h"
@@ -1146,6 +1147,14 @@ static void ggml_vk_get_rows(
     seq.record<kp::OpAlgoDispatch>(s_algo);
 }
 
+template <typename... Args>
+static void ggml_vk_get_rows_f32(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
+        kp::shader_data::op_getrows_f32_comp_spv_len);
+
+    ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
+}
+
 template <typename... Args>
 static void ggml_vk_get_rows_f16(Args&&... args) {
     const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
@@ -1183,7 +1192,7 @@ static void ggml_vk_rope(
     const std::shared_ptr<kp::Tensor>& inB,
     const std::shared_ptr<kp::Tensor>& out,
     uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
+    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
     float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
     int32_t ne01, int32_t ne02, int32_t ne03,
     uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
@@ -1212,14 +1221,14 @@
 
     struct PushConstants {
         uint32_t inAOff, inBOff, outOff;
-        int32_t n_dims, mode, n_orig_ctx;
+        int32_t n_dims, mode, n_ctx_orig;
         float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
         uint32_t nb00, nb01, nb02, nb03;
         int32_t ne0;
         uint32_t nb0, nb1, nb2, nb3;
     } pushConsts {
         safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
-        n_dims, mode, n_orig_ctx,
+        n_dims, mode, n_ctx_orig,
         freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
         nb00, nb01, nb02, nb03,
         ne0,
@@ -1371,6 +1380,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
         return op->ne[3] == 1;
     case GGML_OP_GET_ROWS:
         switch (op->src[0]->type) {
+            case GGML_TYPE_F32:
             case GGML_TYPE_F16:
             case GGML_TYPE_Q4_0:
             case GGML_TYPE_Q4_1:
@@ -1661,7 +1671,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 } break;
             case GGML_OP_GET_ROWS:
                 {
-                    if (src0t == GGML_TYPE_F16) {
+                    if (src0t == GGML_TYPE_F32) {
+                        ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+                    } else if (src0t == GGML_TYPE_F16) {
                         ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
                     } else if (src0t == GGML_TYPE_Q4_0) {
                         ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
@@ -1680,13 +1692,16 @@
 #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
                     GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
 
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
+
                     GGML_ASSERT(ne10 == ne02);
                     GGML_ASSERT(src0t == dstt);
                     // const int n_past = ((int32_t *) dst->op_params)[0];
                     const int n_dims = ((int32_t *) dst->op_params)[1];
                     const int mode = ((int32_t *) dst->op_params)[2];
                     // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
-                    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+                    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
                     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                     memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1696,7 +1711,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
                     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
                     ggml_vk_rope(
-                        seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
+                        seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
                         freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
                         ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
                     );
data/vendor/tmp/llama.cpp/ggml-metal.h
@@ -1,7 +1,7 @@
 // An interface allowing to compute ggml_cgraph with Metal
 //
 // This is a fully functional interface that extends ggml with GPU support for Apple devices.
-// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.)
+// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
 //
 // How it works?
 //
data/vendor/tmp/llama.cpp/ggml-metal.m
@@ -172,8 +172,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -779,6 +783,12 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[1]->type != GGML_TYPE_F16) {
+                return false;
+            }
+            if (op->src[2]->type != GGML_TYPE_F16) {
+                return false;
+            }
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
@@ -2279,7 +2289,7 @@ static enum ggml_status ggml_metal_graph_compute(
             const int n_dims = ((int32_t *) dst->op_params)[1];
             const int mode = ((int32_t *) dst->op_params)[2];
             // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
-            const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+            const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
             float freq_base;
             float freq_scale;
@@ -2296,22 +2306,23 @@ static enum ggml_status ggml_metal_graph_compute(
             memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 
             const bool is_neox = mode & 2;
-            const bool is_glm = mode & 4;
 
-            GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
+            id<MTLComputePipelineState> pipeline = nil;
 
             if (!is_neox) {
-                GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
+                switch (src0->type) {
+                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                    default: GGML_ASSERT(false);
+                };
+            } else {
+                switch (src0->type) {
+                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                    default: GGML_ASSERT(false);
+                };
             }
 
-            id<MTLComputePipelineState> pipeline = nil;
-
-            switch (src0->type) {
-                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
-                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
-                default: GGML_ASSERT(false);
-            };
-
             [encoder setComputePipelineState:pipeline];
             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2339,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
             [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
             [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
             [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
-            [encoder setBytes:&mode length:sizeof( int) atIndex:22];
-            [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
-            [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
-            [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
-            [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
-            [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
-            [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
-            [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
+            [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
+            [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+            [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+            [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+            [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+            [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+            [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
 
             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
         } break;
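
Note on the RoPE hunks above: the single rope_f32/rope_f16 Metal kernel pair is split into NORM and NEOX variants selected by bit 1 of the op's mode flag, and mode itself is no longer passed to the kernel (the setBytes indices shift down by one). A condensed sketch of the new pipeline selection (my own illustration, not gem code):

// Hypothetical restatement of the dispatch logic: bit 1 of mode picks NEOX, dtype picks F32/F16.
enum class RopeKernel { NORM_F32, NORM_F16, NEOX_F32, NEOX_F16 };

static RopeKernel pick_rope_kernel(int mode, bool src_is_f16) {
    const bool is_neox = (mode & 2) != 0;    // matches `const bool is_neox = mode & 2;` above
    if (is_neox) return src_is_f16 ? RopeKernel::NEOX_F16 : RopeKernel::NEOX_F32;
    return src_is_f16 ? RopeKernel::NORM_F16 : RopeKernel::NORM_F32;
}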