llama_cpp 0.16.0 → 0.16.1

Files changed (134)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/ext/llama_cpp/extconf.rb +2 -0
  4. data/ext/llama_cpp/llama_cpp.cpp +2 -0
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +2 -0
  7. data/vendor/tmp/llama.cpp/Makefile +110 -53
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +78 -22
  9. data/vendor/tmp/llama.cpp/ggml-backend-impl.h +20 -8
  10. data/vendor/tmp/llama.cpp/ggml-backend.c +178 -64
  11. data/vendor/tmp/llama.cpp/ggml-backend.h +3 -3
  12. data/vendor/tmp/llama.cpp/ggml-blas.cpp +363 -0
  13. data/vendor/tmp/llama.cpp/ggml-blas.h +23 -0
  14. data/vendor/tmp/llama.cpp/ggml-common.h +6 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +1 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +21 -9
  17. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +1 -1
  18. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +15 -1491
  19. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +76 -61
  20. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +77 -10
  21. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +1 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +1 -1
  23. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +1 -1
  24. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +1 -1
  25. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +1 -1
  26. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +1 -1
  27. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +1 -1
  28. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +1 -1
  29. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +1 -1
  30. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +1 -1
  31. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +1 -1
  32. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +1 -1
  33. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +1 -1
  34. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +1 -1
  35. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +1 -1
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +1 -1
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +1 -1
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +1 -1
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +1 -1
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +1 -1
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +1 -1
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +1 -1
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +1 -1
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +1 -1
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +1 -1
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +1 -1
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +1 -1
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +1 -1
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +1 -1
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +1 -1
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +1 -1
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +1 -1
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +1 -1
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +1 -1
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +1 -1
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +1 -1
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +1 -1
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +1 -1
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +1 -1
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +1 -1
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +1 -1
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +1 -1
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +1 -1
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +1 -1
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +1 -1
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +1 -1
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +1 -1
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +1 -1
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +1 -1
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +1 -1
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +1 -1
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +1 -1
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +1 -1
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +1 -1
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +1 -1
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +1 -1
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +1 -1
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +1 -1
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +1 -1
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +1 -1
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +1 -1
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +1 -1
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +1 -1
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +1 -1
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +1 -1
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +1 -1
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +1 -1
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +1 -1
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +1 -1
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +1 -1
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +1 -1
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +1 -1
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +1 -1
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +1 -1
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +1 -1
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +1 -1
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +1 -1
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +1 -1
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +1 -1
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +1 -1
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +1 -1
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +1 -1
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +1 -1
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +1 -1
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +1 -1
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +1 -1
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +1 -1
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +1 -1
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +1 -1
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +1 -1
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +1 -1
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +1 -1
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +20 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda.cu +95 -129
  125. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +8 -7
  126. data/vendor/tmp/llama.cpp/ggml-metal.m +11 -9
  127. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +13 -12
  128. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +19 -23
  129. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +1230 -1129
  130. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +181 -148
  131. data/vendor/tmp/llama.cpp/ggml.c +102 -275
  132. data/vendor/tmp/llama.cpp/llama.cpp +103 -47
  133. data/vendor/tmp/llama.cpp/llama.h +4 -0
  134. metadata +15 -3
@@ -1,9 +1,47 @@
  #include "mmvq.cuh"
  #include "vecdotq.cuh"

- typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
+
+ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
+     return type == GGML_TYPE_Q4_0    ? vec_dot_q4_0_q8_1 :
+            type == GGML_TYPE_Q4_1    ? vec_dot_q4_1_q8_1 :
+            type == GGML_TYPE_Q5_0    ? vec_dot_q5_0_q8_1 :
+            type == GGML_TYPE_Q5_1    ? vec_dot_q5_1_q8_1 :
+            type == GGML_TYPE_Q8_0    ? vec_dot_q8_0_q8_1 :
+            type == GGML_TYPE_Q2_K    ? vec_dot_q2_K_q8_1 :
+            type == GGML_TYPE_Q3_K    ? vec_dot_q3_K_q8_1 :
+            type == GGML_TYPE_Q4_K    ? vec_dot_q4_K_q8_1 :
+            type == GGML_TYPE_Q5_K    ? vec_dot_q5_K_q8_1 :
+            type == GGML_TYPE_Q6_K    ? vec_dot_q6_K_q8_1 :
+            type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
+            type == GGML_TYPE_IQ2_XS  ? vec_dot_iq2_xs_q8_1 :
+            type == GGML_TYPE_IQ2_S   ? vec_dot_iq2_s_q8_1 :
+            type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
+            type == GGML_TYPE_IQ1_S   ? vec_dot_iq1_s_q8_1 :
+            type == GGML_TYPE_IQ1_M   ? vec_dot_iq1_m_q8_1 :
+            type == GGML_TYPE_IQ4_NL  ? vec_dot_iq4_nl_q8_1 :
+            type == GGML_TYPE_IQ4_XS  ? vec_dot_iq4_xs_q8_1 :
+            type == GGML_TYPE_IQ3_S   ? vec_dot_iq3_s_q8_1 :
+            nullptr;
+ }
+
+ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
+     return type == GGML_TYPE_Q4_0   ? VDR_Q4_0_Q8_1_MMVQ :
+            type == GGML_TYPE_Q4_1   ? VDR_Q4_1_Q8_1_MMVQ :
+            type == GGML_TYPE_Q5_0   ? VDR_Q5_0_Q8_1_MMVQ :
+            type == GGML_TYPE_Q5_1   ? VDR_Q5_1_Q8_1_MMVQ :
+            type == GGML_TYPE_Q8_0   ? VDR_Q8_0_Q8_1_MMVQ :
+            type == GGML_TYPE_Q2_K   ? VDR_Q2_K_Q8_1_MMVQ :
+            type == GGML_TYPE_Q3_K   ? VDR_Q3_K_Q8_1_MMVQ :
+            type == GGML_TYPE_Q4_K   ? VDR_Q4_K_Q8_1_MMVQ :
+            type == GGML_TYPE_Q5_K   ? VDR_Q5_K_Q8_1_MMVQ :
+            type == GGML_TYPE_Q6_K   ? VDR_Q6_K_Q8_1_MMVQ :
+            type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ :
+            1;
+ }

- template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+ template <ggml_type type, int ncols_y>
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
  // tell the compiler to use as many registers as it wants, see nwarps definition below
  __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
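The two constexpr helpers above are the crux of this refactor: the old six-parameter kernel template (ncols_y, qk, qi, block type, vdr, vec-dot pointer) collapses to just (type, ncols_y), and each per-type constant is recovered at compile time. Below is a minimal standalone sketch of the same enum-to-function-pointer dispatch pattern; all names in it (Kind, dot_mul, dot_add, get_dot, combine) are hypothetical and not part of llama.cpp.

    // Sketch of constexpr enum-to-function-pointer dispatch, as used above.
    // All names here are hypothetical.
    #include <cstdio>

    enum Kind { KIND_MUL, KIND_ADD };

    typedef float (*dot_fn)(float, float);

    static constexpr float dot_mul(float x, float y) { return x * y; }
    static constexpr float dot_add(float x, float y) { return x + y; }

    // A constant expression when called with a constant argument, so the
    // indirect call below is resolved per template instantiation.
    static constexpr dot_fn get_dot(Kind k) {
        return k == KIND_MUL ? dot_mul :
               k == KIND_ADD ? dot_add :
               nullptr;
    }

    template <Kind k>
    static float combine(float x, float y) {
        constexpr dot_fn f = get_dot(k); // fixed at compile time
        return f(x, y);
    }

    int main() {
        printf("%g %g\n", combine<KIND_MUL>(2.0f, 3.0f), combine<KIND_ADD>(2.0f, 3.0f));
        return 0;
    }

Because the selected function is a constexpr value of the instantiation, the compiler can presumably inline the chosen vec-dot routine into each mul_mat_vec_q<type, ncols_y> kernel, so the refactor mainly changes how the kernels are parameterized rather than what code is generated.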
@@ -12,6 +50,12 @@ static __global__ void mul_mat_vec_q(
      const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {

+     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
+     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
+     constexpr int vdr = get_vdr_mmvq(type);
+
+     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
      constexpr int nwarps              = 1;
      constexpr int rows_per_cuda_block = 1;
@@ -29,7 +73,6 @@ static __global__ void mul_mat_vec_q(
      // partial sum for each thread
      float tmp[ncols_y][rows_per_cuda_block] = {0.0f};

-     const block_q_t  * x = (const block_q_t  *) vx;
      const block_q8_1 * y = (const block_q8_1 *) vy;

      for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@@ -42,8 +85,7 @@ static __global__ void mul_mat_vec_q(
          for (int j = 0; j < ncols_y; ++j) {
  #pragma unroll
              for (int i = 0; i < rows_per_cuda_block; ++i) {
-                 tmp[j][i] += vec_dot_q_cuda(
-                     &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
+                 tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
              }
          }
      }
@@ -81,12 +123,12 @@ static __global__ void mul_mat_vec_q(
      }
  }

- template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
+ template <ggml_type type>
  static void mul_mat_vec_q_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     GGML_ASSERT(ncols_x % qk == 0);
+     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
      GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

      int id = ggml_cuda_get_device();
@@ -124,36 +166,28 @@ static void mul_mat_vec_q_cuda(

      switch (ncols_y) {
          case 1:
-             mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
-                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+             mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
              break;
          case 2:
-             mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
-                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+             mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
              break;
          case 3:
-             mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
-                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+             mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
              break;
          case 4:
-             mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
-                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+             mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
              break;
          case 5:
-             mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
-                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+             mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
              break;
          case 6:
-             mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
-                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+             mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
              break;
          case 7:
-             mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
-                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+             mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
              break;
          case 8:
-             mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
-                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+             mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
              break;
          default:
              GGML_ASSERT(false);
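Since ncols_y sizes the tmp array inside the kernel, it must be a compile-time constant; the switch above is the standard way to map a runtime batch width onto a fixed set of template instantiations. A minimal host-only sketch of that pattern, with hypothetical names (launch_impl, dispatch_ncols) that are not from llama.cpp:

    // Sketch: turning a runtime value into a template argument via a switch.
    #include <cassert>

    template <int ncols_y>
    static void launch_impl(const float * x, float * dst, int n) {
        float acc[ncols_y] = {0.0f}; // legal only because ncols_y is a constant
        for (int j = 0; j < ncols_y; ++j) {
            for (int i = 0; i < n; ++i) {
                acc[j] += x[j*n + i];
            }
            dst[j] = acc[j];
        }
    }

    static void dispatch_ncols(const float * x, float * dst, int n, int ncols_y) {
        switch (ncols_y) {
            case 1: launch_impl<1>(x, dst, n); break;
            case 2: launch_impl<2>(x, dst, n); break;
            case 3: launch_impl<3>(x, dst, n); break;
            case 4: launch_impl<4>(x, dst, n); break;
            default: assert(false && "unsupported ncols_y");
        }
    }

The cost of the pattern is one compiled kernel per (type, ncols_y) pair, which is why the batch width is capped at MMVQ_MAX_BATCH_SIZE.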
@@ -165,152 +199,133 @@ static void mul_mat_vec_q4_0_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_q4_1_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_q5_0_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_q5_1_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_q8_0_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_q2_K_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_q3_K_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_q4_K_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_q5_K_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_q6_K_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_iq2_xxs_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_iq2_xs_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_iq2_s_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_iq3_xxs_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_iq1_s_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_iq1_m_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_iq4_nl_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_iq4_xs_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  static void mul_mat_vec_iq3_s_q8_1_cuda(
      const void * vx, const void * vy, float * dst,
      const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

-     mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
-         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+     mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
  }

  void ggml_cuda_op_mul_mat_vec_q(
@@ -1,22 +1,23 @@
  #include "quantize.cuh"
+ #include <cstdint>

- static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) {
-     const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
+     const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

-     if (ix >= kx_padded) {
+     if (ix0 >= kx0_padded) {
          return;
      }

-     const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y;
+     const int64_t ix1 = blockIdx.y;

-     const int64_t i_padded = (int64_t)iy*kx_padded + ix;
+     const int64_t i_padded = ix1*kx0_padded + ix0;

      block_q8_1 * y = (block_q8_1 *) vy;

      const int64_t ib  = i_padded / QK8_1; // block index
      const int64_t iqs = i_padded % QK8_1; // quant index

-     const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
+     const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
      float amax = fabsf(xi);
      float sum = xi;

@@ -36,10 +37,76 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) {
      reinterpret_cast<half&>(y[ib].ds.y) = sum;
  }

- void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) {
-     const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-     const dim3 num_blocks(block_num_x, ky, 1);
+ template <bool need_sum>
+ static __global__ void quantize_mmq_q8_1(
+     const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+
+     const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+     if (ix0 >= kx0_padded) {
+         return;
+     }
+
+     const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+
+     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+     const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
+     const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;              // block index in channel
+     const int64_t iqs = ix0 % (4*QK8_1);                                       // quant index in block
+
+     const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
+     float amax = fabsf(xi);
+
+     amax = warp_reduce_max(amax);
+
+     float sum;
+     if (need_sum) {
+         sum = warp_reduce_sum(xi);
+     }
+
+     const float  d = amax / 127;
+     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+     y[ib].qs[iqs] = q;
+
+     if (iqs % QK8_1 != 0) {
+         return;
+     }
+
+     if (need_sum) {
+         y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
+     } else {
+         ((float *) y[ib].ds)[iqs/QK8_1] = d;
+     }
+ }
+
+ void quantize_row_q8_1_cuda(
+     const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+     const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+     GGML_ASSERT(kx0_padded % QK8_1 == 0);
+
+     const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+     const dim3 num_blocks(block_num_x, kx1*channels, 1);
      const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
+
+     GGML_UNUSED(type_x);
  }

+ void quantize_mmq_q8_1_cuda(
+     const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+     const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+     GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+
+     const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+     const dim3 num_blocks(block_num_x, kx1, channels);
+     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+     if (mmq_need_sum(type_x)) {
+         quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+     } else {
+         quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+     }
+ }
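Both kernels above implement the same per-block q8_1 math: within a block of QK8_1 values the scale is d = max|x| / 127, each value is stored as q = round(x/d), and the pair (d, sum) is kept alongside the quants for dequantization and the q8_1 bias term. A host-side reference sketch of just that math (quantize_block_q8_1_ref is a hypothetical name, not a llama.cpp function):

    // Reference for the per-block q8_1 math used by the kernels above:
    // q = round(x/d) with d = max(|x|)/127.
    #include <cmath>
    #include <cstdint>

    static void quantize_block_q8_1_ref(const float * x, int n, int8_t * q,
                                        float * d_out, float * sum_out) {
        float amax = 0.0f, sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            amax = fmaxf(amax, fabsf(x[i]));
            sum += x[i];
        }
        const float d = amax / 127.0f;
        for (int i = 0; i < n; ++i) {
            q[i] = (amax == 0.0f) ? 0 : (int8_t) roundf(x[i] / d);
        }
        *d_out   = d;   // scale, stored as half in ds.x
        *sum_out = sum; // block sum, stored as half in ds.y when needed
    }

The new MMQ variant stores the same quantities in the block_q8_1_mmq layout, writing one (d, sum) pair per QK8_1 sub-block and skipping the sum entirely when mmq_need_sum(type_x) says the destination type does not use it.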
@@ -130,6 +130,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
      const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
      const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
      if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
          switch (ncols_x) {
              case 32:
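The new FIXME concerns this capacity check: smpb is the device's default shared memory per block, and on Ampere or newer a kernel can opt in to considerably more dynamic shared memory. A hedged sketch of what such an opt-in typically looks like, not the llama.cpp code (my_kernel and set_shmem_limit are placeholders):

    // Sketch: opting a kernel into more than the default 48 KiB of
    // dynamic shared memory. my_kernel / set_shmem_limit are hypothetical.
    #include <cuda_runtime.h>

    __global__ void my_kernel(float * dst) { dst[0] = 0.0f; }

    static bool set_shmem_limit(size_t shmem_bytes, int device) {
        int max_optin = 0;
        cudaDeviceGetAttribute(&max_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
        if ((size_t) max_optin < shmem_bytes) {
            return false; // device cannot provide this much; take a fallback path
        }
        // Must be set once per kernel before launching with shmem > 48 KiB.
        cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) shmem_bytes);
        return true;
    }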
@@ -1,4 +1,4 @@
- // This file has been autogenerated by generate-variants.py, do not edit manually.
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.

  #include "../fattn-vec-f16.cuh"

(The same one-line comment change, generate-variants.py → generate_cu_files.py, appears in each of the remaining autogenerated fattn template-instance files.)