llama_cpp 0.16.0 → 0.16.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +13 -0
  3. data/ext/llama_cpp/extconf.rb +3 -0
  4. data/ext/llama_cpp/llama_cpp.cpp +14 -0
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +4 -0
  7. data/vendor/tmp/llama.cpp/Makefile +119 -54
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +78 -22
  9. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +20 -8
  10. data/vendor/tmp/llama.cpp/ggml-backend.c +190 -65
  11. data/vendor/tmp/llama.cpp/ggml-backend.h +6 -3
  12. data/vendor/tmp/llama.cpp/ggml-blas.cpp +363 -0
  13. data/vendor/tmp/llama.cpp/ggml-blas.h +23 -0
  14. data/vendor/tmp/llama.cpp/ggml-common.h +6 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +1 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +21 -9
  17. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +1 -1
  18. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +15 -1491
  19. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +77 -62
  20. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +77 -10
  21. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +1 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +1 -1
  23. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +1 -1
  24. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +1 -1
  25. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +1 -1
  26. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +1 -1
  27. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +1 -1
  28. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +1 -1
  29. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +1 -1
  30. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +1 -1
  31. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +1 -1
  32. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +1 -1
  33. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +1 -1
  34. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +1 -1
  35. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +1 -1
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +1 -1
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +1 -1
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +1 -1
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +1 -1
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +1 -1
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +1 -1
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +1 -1
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +1 -1
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +1 -1
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +1 -1
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +1 -1
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +1 -1
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +1 -1
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +1 -1
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +1 -1
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +1 -1
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +1 -1
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +1 -1
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +1 -1
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +1 -1
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +1 -1
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +1 -1
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +1 -1
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +1 -1
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +1 -1
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +1 -1
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +1 -1
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +1 -1
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +1 -1
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +1 -1
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +1 -1
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +1 -1
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +1 -1
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +1 -1
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +1 -1
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +1 -1
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +1 -1
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +1 -1
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +1 -1
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +1 -1
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +1 -1
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +1 -1
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +1 -1
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +1 -1
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +1 -1
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +1 -1
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +1 -1
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +1 -1
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +1 -1
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +1 -1
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +1 -1
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +1 -1
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +1 -1
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +1 -1
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +1 -1
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +1 -1
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +1 -1
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +1 -1
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +1 -1
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +1 -1
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +1 -1
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +1 -1
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +1 -1
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +1 -1
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +1 -1
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +1 -1
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +1 -1
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +1 -1
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +1 -1
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +1 -1
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +1 -1
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +1 -1
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +1 -1
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +1 -1
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +1 -1
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +1 -1
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +1 -1
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +48 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda.cu +95 -129
  125. data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
  126. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +8 -7
  127. data/vendor/tmp/llama.cpp/ggml-metal.m +17 -9
  128. data/vendor/tmp/llama.cpp/ggml-quants.c +982 -368
  129. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +21 -15
  130. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +2133 -13215
  131. data/vendor/tmp/llama.cpp/ggml-sycl.h +1 -10
  132. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +28826 -25037
  133. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +438 -493
  134. data/vendor/tmp/llama.cpp/ggml.c +158 -414
  135. data/vendor/tmp/llama.cpp/ggml.h +6 -0
  136. data/vendor/tmp/llama.cpp/llama.cpp +628 -279
  137. data/vendor/tmp/llama.cpp/llama.h +9 -1
  138. data/vendor/tmp/llama.cpp/sgemm.cpp +2 -0
  139. data/vendor/tmp/llama.cpp/unicode-data.cpp +851 -801
  140. data/vendor/tmp/llama.cpp/unicode.cpp +33 -19
  141. data/vendor/tmp/llama.cpp/unicode.h +1 -1
  142. metadata +15 -3
data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu

@@ -1,9 +1,47 @@
 #include "mmvq.cuh"
 #include "vecdotq.cuh"
 
-typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
+
+static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
+    return type == GGML_TYPE_Q4_0    ? vec_dot_q4_0_q8_1 :
+           type == GGML_TYPE_Q4_1    ? vec_dot_q4_1_q8_1 :
+           type == GGML_TYPE_Q5_0    ? vec_dot_q5_0_q8_1 :
+           type == GGML_TYPE_Q5_1    ? vec_dot_q5_1_q8_1 :
+           type == GGML_TYPE_Q8_0    ? vec_dot_q8_0_q8_1 :
+           type == GGML_TYPE_Q2_K    ? vec_dot_q2_K_q8_1 :
+           type == GGML_TYPE_Q3_K    ? vec_dot_q3_K_q8_1 :
+           type == GGML_TYPE_Q4_K    ? vec_dot_q4_K_q8_1 :
+           type == GGML_TYPE_Q5_K    ? vec_dot_q5_K_q8_1 :
+           type == GGML_TYPE_Q6_K    ? vec_dot_q6_K_q8_1 :
+           type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
+           type == GGML_TYPE_IQ2_XS  ? vec_dot_iq2_xs_q8_1 :
+           type == GGML_TYPE_IQ2_S   ? vec_dot_iq2_s_q8_1 :
+           type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
+           type == GGML_TYPE_IQ1_S   ? vec_dot_iq1_s_q8_1 :
+           type == GGML_TYPE_IQ1_M   ? vec_dot_iq1_m_q8_1 :
+           type == GGML_TYPE_IQ4_NL  ? vec_dot_iq4_nl_q8_1 :
+           type == GGML_TYPE_IQ4_XS  ? vec_dot_iq4_xs_q8_1 :
+           type == GGML_TYPE_IQ3_S   ? vec_dot_iq3_s_q8_1 :
+           nullptr;
+}
+
+static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
+    return type == GGML_TYPE_Q4_0   ? VDR_Q4_0_Q8_1_MMVQ :
+           type == GGML_TYPE_Q4_1   ? VDR_Q4_1_Q8_1_MMVQ :
+           type == GGML_TYPE_Q5_0   ? VDR_Q5_0_Q8_1_MMVQ :
+           type == GGML_TYPE_Q5_1   ? VDR_Q5_1_Q8_1_MMVQ :
+           type == GGML_TYPE_Q8_0   ? VDR_Q8_0_Q8_1_MMVQ :
+           type == GGML_TYPE_Q2_K   ? VDR_Q2_K_Q8_1_MMVQ :
+           type == GGML_TYPE_Q3_K   ? VDR_Q3_K_Q8_1_MMVQ :
+           type == GGML_TYPE_Q4_K   ? VDR_Q4_K_Q8_1_MMVQ :
+           type == GGML_TYPE_Q5_K   ? VDR_Q5_K_Q8_1_MMVQ :
+           type == GGML_TYPE_Q6_K   ? VDR_Q6_K_Q8_1_MMVQ :
+           type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ :
+           1;
+}
 
-template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+template <ggml_type type, int ncols_y>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
 __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
@@ -12,6 +50,12 @@ static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
+    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
+    constexpr int vdr = get_vdr_mmvq(type);
+
+    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
     constexpr int nwarps              = 1;
     constexpr int rows_per_cuda_block = 1;
@@ -29,7 +73,6 @@ static __global__ void mul_mat_vec_q(
     // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
-    const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@@ -42,8 +85,7 @@ static __global__ void mul_mat_vec_q(
         for (int j = 0; j < ncols_y; ++j) {
 #pragma unroll
             for (int i = 0; i < rows_per_cuda_block; ++i) {
-                tmp[j][i] += vec_dot_q_cuda(
-                    &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
+                tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
             }
         }
     }
@@ -75,18 +117,18 @@ static __global__ void mul_mat_vec_q(
             tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }
 
-        if (threadIdx.x < rows_per_cuda_block) {
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
             dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
 }
 
-template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
+template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    GGML_ASSERT(ncols_x % qk == 0);
+    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
     int id = ggml_cuda_get_device();
@@ -124,36 +166,28 @@ static void mul_mat_vec_q_cuda(
 
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
         case 2:
-            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
         case 3:
-            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
         case 4:
-            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
         case 5:
-            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
         case 6:
-            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
         case 7:
-            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
         case 8:
-            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
         default:
             GGML_ASSERT(false);
@@ -165,152 +199,133 @@ static void mul_mat_vec_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q4_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q5_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q5_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q8_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q2_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q3_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q4_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q5_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_q6_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq2_xxs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq2_xs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq2_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq3_xxs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq1_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq1_m_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq4_nl_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq4_xs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 static void mul_mat_vec_iq3_s_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
-        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 void ggml_cuda_op_mul_mat_vec_q(
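The core of the mmvq.cu change: the kernel's five template parameters (block size, quant-iteration count, block struct, vdr, and dot-product function) collapse into a single ggml_type parameter, with constexpr lookup functions recovering the per-type constants at compile time. Because the ternary chains are evaluated with a constant template argument, they fold away and the function pointer becomes a direct call, which is what lets every host wrapper shrink to one line. A minimal self-contained sketch of the pattern, with hypothetical names (my_type, get_dot, get_block_size — not ggml's):

// Sketch: one enum template parameter selects per-type constants and a
// device function pointer at compile time via constexpr lookup functions.
#include <cuda_runtime.h>

enum my_type { TYPE_A, TYPE_B };

__device__ float dot_a(const float * x, int i) { return 2.0f*x[i]; }
__device__ float dot_b(const float * x, int i) { return 3.0f*x[i]; }

typedef float (*dot_fn_t)(const float *, int);

static constexpr __device__ dot_fn_t get_dot(my_type t) {
    return t == TYPE_A ? dot_a :
           t == TYPE_B ? dot_b :
           nullptr;
}

static constexpr __device__ int get_block_size(my_type t) {
    return t == TYPE_A ? 32 : 64;
}

template <my_type type>
__global__ void kernel(const float * x, float * out, int n) {
    // Both fold to compile-time constants: the ternary chains are evaluated
    // with a constant argument, so the call through `dot` is resolved
    // statically rather than going through a runtime function pointer.
    constexpr int      qk  = get_block_size(type);
    constexpr dot_fn_t dot = get_dot(type);

    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n/qk) {
        out[i] = dot(x, i*qk);
    }
}

int main() {
    // Host-side dispatch mirrors mul_mat_vec_q_cuda: instantiating by enum
    // value replaces spelling out five template arguments per quant type.
    float *x, *out;
    cudaMalloc(&x,   1024*sizeof(float));
    cudaMalloc(&out, 1024*sizeof(float));
    kernel<TYPE_A><<<1, 32>>>(x, out, 1024);
    cudaDeviceSynchronize();
    cudaFree(x); cudaFree(out);
    return 0;
}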
data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu

@@ -1,22 +1,23 @@
 #include "quantize.cuh"
+#include <cstdint>
 
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) {
-    const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
+    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (ix >= kx_padded) {
+    if (ix0 >= kx0_padded) {
         return;
     }
 
-    const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y;
+    const int64_t ix1 = blockIdx.y;
 
-    const int64_t i_padded = (int64_t)iy*kx_padded + ix;
+    const int64_t i_padded = ix1*kx0_padded + ix0;
 
     block_q8_1 * y = (block_q8_1 *) vy;
 
     const int64_t ib = i_padded / QK8_1; // block index
     const int64_t iqs = i_padded % QK8_1; // quant index
 
-    const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
+    const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
     float amax = fabsf(xi);
     float sum = xi;
 
@@ -36,10 +37,76 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
-void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) {
-    const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    const dim3 num_blocks(block_num_x, ky, 1);
+template <bool need_sum>
+static __global__ void quantize_mmq_q8_1(
+    const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+
+    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (ix0 >= kx0_padded) {
+        return;
+    }
+
+    const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+
+    block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+    const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
+    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;              // block index in channel
+    const int64_t iqs = ix0 % (4*QK8_1);                                       // quant index in block
+
+    const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
+    float amax = fabsf(xi);
+
+    amax = warp_reduce_max(amax);
+
+    float sum;
+    if (need_sum) {
+        sum = warp_reduce_sum(xi);
+    }
+
+    const float  d = amax / 127;
+    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+    y[ib].qs[iqs] = q;
+
+    if (iqs % QK8_1 != 0) {
+        return;
+    }
+
+    if (need_sum) {
+        y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
+    } else {
+        ((float *) y[ib].ds)[iqs/QK8_1] = d;
+    }
+}
+
+void quantize_row_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+
+    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, kx1*channels, 1);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
+
+    GGML_UNUSED(type_x);
 }
 
+void quantize_mmq_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+
+    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, kx1, channels);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+    if (mmq_need_sum(type_x)) {
+        quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+    } else {
+        quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+    }
+}
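For reference, the arithmetic both kernels implement per block: the scale is d = amax/127, where amax is the block's absolute maximum (computed with a warp reduction above); each value is stored as q = round(x/d); dequantization is x ≈ q*d, so the per-element round-trip error is at most d/2. A host-side sketch of the round trip, assuming QK8_1 == 32 as in ggml (plain C++, no ggml types):

// Sketch of the q8_1-style quantization round trip: 32 floats share one scale.
#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr int QK = 32; // stand-in for QK8_1

void quantize_block(const float * x, int8_t * q, float & d, float & sum) {
    float amax = 0.0f;
    sum = 0.0f;
    for (int i = 0; i < QK; ++i) {  // serial stand-in for warp_reduce_max/_sum
        amax = fmaxf(amax, fabsf(x[i]));
        sum += x[i];
    }
    d = amax / 127.0f;              // shared scale for the whole block
    for (int i = 0; i < QK; ++i) {
        q[i] = amax == 0.0f ? 0 : (int8_t) roundf(x[i] / d);
    }
}

int main() {
    float x[QK], d, sum;
    int8_t q[QK];
    for (int i = 0; i < QK; ++i) x[i] = 0.1f*(i - 16);
    quantize_block(x, q, d, sum);
    // dequantized value q*d differs from x by at most d/2
    printf("d=%f x[3]=%f dequant=%f\n", d, x[3], q[3]*d);
    return 0;
}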
data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu

@@ -130,6 +130,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
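The FIXME refers to CUDA's default 48 KiB cap on dynamic shared memory per block; on Volta and newer, a kernel can opt in to the device's larger per-block maximum (reported by sharedMemPerBlockOptin, roughly 2-4x bigger on Ampere-class parts) via cudaFuncSetAttribute. A standalone sketch of the opt-in, using a placeholder kernel rather than the ggml softmax kernel:

// Sketch: raising a kernel's dynamic shared memory limit above the default.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void demo_kernel(float * data, int n) {
    extern __shared__ float buf[];      // dynamic shared memory
    const int i = threadIdx.x;
    buf[i] = i < n ? data[i] : 0.0f;
    __syncthreads();
    if (i < n) {
        data[i] = buf[i] + 1.0f;
    }
}

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);

    // sharedMemPerBlock is the default limit (presumably what `smpb` holds in
    // ggml's device info); sharedMemPerBlockOptin is the opt-in maximum,
    // e.g. ~164 KiB on A100 versus the 48 KiB default.
    printf("default: %zu, opt-in: %zu\n",
           prop.sharedMemPerBlock, prop.sharedMemPerBlockOptin);

    size_t shmem = prop.sharedMemPerBlockOptin;
    cudaFuncSetAttribute(demo_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem);

    float * data;
    cudaMalloc(&data, 256*sizeof(float));
    demo_kernel<<<1, 256, shmem>>>(data, 256);  // launch with the raised limit
    cudaDeviceSynchronize();
    cudaFree(data);
    return 0;
}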
data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-*.cu
(identical one-line hunk in each of the fattn-vec-f16 template-instance files listed above)

@@ -1,4 +1,4 @@
-// This file has been autogenerated by generate-variants.py, do not edit manually.
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
 #include "../fattn-vec-f16.cuh"
 