llama_cpp 0.16.0 → 0.16.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/extconf.rb +3 -0
- data/ext/llama_cpp/llama_cpp.cpp +14 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +119 -54
- data/vendor/tmp/llama.cpp/ggml-alloc.c +78 -22
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +20 -8
- data/vendor/tmp/llama.cpp/ggml-backend.c +190 -65
- data/vendor/tmp/llama.cpp/ggml-backend.h +6 -3
- data/vendor/tmp/llama.cpp/ggml-blas.cpp +363 -0
- data/vendor/tmp/llama.cpp/ggml-blas.h +23 -0
- data/vendor/tmp/llama.cpp/ggml-common.h +6 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +1 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +21 -9
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +15 -1491
- data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +77 -62
- data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +77 -10
- data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +1 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +48 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +95 -129
- data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +8 -7
- data/vendor/tmp/llama.cpp/ggml-metal.m +17 -9
- data/vendor/tmp/llama.cpp/ggml-quants.c +982 -368
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +21 -15
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +2133 -13215
- data/vendor/tmp/llama.cpp/ggml-sycl.h +1 -10
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +28826 -25037
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +438 -493
- data/vendor/tmp/llama.cpp/ggml.c +158 -414
- data/vendor/tmp/llama.cpp/ggml.h +6 -0
- data/vendor/tmp/llama.cpp/llama.cpp +628 -279
- data/vendor/tmp/llama.cpp/llama.h +9 -1
- data/vendor/tmp/llama.cpp/sgemm.cpp +2 -0
- data/vendor/tmp/llama.cpp/unicode-data.cpp +851 -801
- data/vendor/tmp/llama.cpp/unicode.cpp +33 -19
- data/vendor/tmp/llama.cpp/unicode.h +1 -1
- metadata +15 -3
|
@@ -1,9 +1,47 @@
|
|
|
1
1
|
#include "mmvq.cuh"
|
|
2
2
|
#include "vecdotq.cuh"
|
|
3
3
|
|
|
4
|
-
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
|
|
4
|
+
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
|
|
5
|
+
|
|
6
|
+
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
|
|
7
|
+
return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
|
|
8
|
+
type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
|
|
9
|
+
type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
|
|
10
|
+
type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
|
|
11
|
+
type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
|
|
12
|
+
type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
|
|
13
|
+
type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
|
|
14
|
+
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
|
|
15
|
+
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
|
|
16
|
+
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
|
|
17
|
+
type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
|
|
18
|
+
type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
|
|
19
|
+
type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
|
|
20
|
+
type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
|
|
21
|
+
type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
|
|
22
|
+
type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
|
|
23
|
+
type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
|
|
24
|
+
type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
|
|
25
|
+
type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
|
|
26
|
+
nullptr;
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
|
30
|
+
return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
|
|
31
|
+
type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
|
|
32
|
+
type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
|
|
33
|
+
type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
|
|
34
|
+
type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
|
|
35
|
+
type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
|
|
36
|
+
type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
|
|
37
|
+
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
|
|
38
|
+
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
|
|
39
|
+
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
|
|
40
|
+
type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ :
|
|
41
|
+
1;
|
|
42
|
+
}
|
|
5
43
|
|
|
6
|
-
template <
|
|
44
|
+
template <ggml_type type, int ncols_y>
|
|
7
45
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
|
8
46
|
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
|
9
47
|
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
|
@@ -12,6 +50,12 @@ static __global__ void mul_mat_vec_q(
|
|
|
12
50
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
|
13
51
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
|
14
52
|
|
|
53
|
+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
|
54
|
+
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
|
55
|
+
constexpr int vdr = get_vdr_mmvq(type);
|
|
56
|
+
|
|
57
|
+
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
|
58
|
+
|
|
15
59
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
|
16
60
|
constexpr int nwarps = 1;
|
|
17
61
|
constexpr int rows_per_cuda_block = 1;
|
|
@@ -29,7 +73,6 @@ static __global__ void mul_mat_vec_q(
|
|
|
29
73
|
// partial sum for each thread
|
|
30
74
|
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
|
|
31
75
|
|
|
32
|
-
const block_q_t * x = (const block_q_t *) vx;
|
|
33
76
|
const block_q8_1 * y = (const block_q8_1 *) vy;
|
|
34
77
|
|
|
35
78
|
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
|
|
@@ -42,8 +85,7 @@ static __global__ void mul_mat_vec_q(
|
|
|
42
85
|
for (int j = 0; j < ncols_y; ++j) {
|
|
43
86
|
#pragma unroll
|
|
44
87
|
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
|
45
|
-
tmp[j][i] += vec_dot_q_cuda(
|
|
46
|
-
&x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
|
|
88
|
+
tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
|
|
47
89
|
}
|
|
48
90
|
}
|
|
49
91
|
}
|
|
@@ -75,18 +117,18 @@ static __global__ void mul_mat_vec_q(
|
|
|
75
117
|
tmp[j][i] = warp_reduce_sum(tmp[j][i]);
|
|
76
118
|
}
|
|
77
119
|
|
|
78
|
-
if (threadIdx.x < rows_per_cuda_block) {
|
|
120
|
+
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
|
|
79
121
|
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
|
|
80
122
|
}
|
|
81
123
|
}
|
|
82
124
|
}
|
|
83
125
|
|
|
84
|
-
template <
|
|
126
|
+
template <ggml_type type>
|
|
85
127
|
static void mul_mat_vec_q_cuda(
|
|
86
128
|
const void * vx, const void * vy, float * dst,
|
|
87
129
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
88
130
|
|
|
89
|
-
GGML_ASSERT(ncols_x %
|
|
131
|
+
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
|
|
90
132
|
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
|
|
91
133
|
|
|
92
134
|
int id = ggml_cuda_get_device();
|
|
@@ -124,36 +166,28 @@ static void mul_mat_vec_q_cuda(
|
|
|
124
166
|
|
|
125
167
|
switch (ncols_y) {
|
|
126
168
|
case 1:
|
|
127
|
-
mul_mat_vec_q<1,
|
|
128
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
169
|
+
mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
129
170
|
break;
|
|
130
171
|
case 2:
|
|
131
|
-
mul_mat_vec_q<2,
|
|
132
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
172
|
+
mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
133
173
|
break;
|
|
134
174
|
case 3:
|
|
135
|
-
mul_mat_vec_q<3,
|
|
136
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
175
|
+
mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
137
176
|
break;
|
|
138
177
|
case 4:
|
|
139
|
-
mul_mat_vec_q<4,
|
|
140
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
178
|
+
mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
141
179
|
break;
|
|
142
180
|
case 5:
|
|
143
|
-
mul_mat_vec_q<5,
|
|
144
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
181
|
+
mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
145
182
|
break;
|
|
146
183
|
case 6:
|
|
147
|
-
mul_mat_vec_q<6,
|
|
148
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
184
|
+
mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
149
185
|
break;
|
|
150
186
|
case 7:
|
|
151
|
-
mul_mat_vec_q<7,
|
|
152
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
187
|
+
mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
153
188
|
break;
|
|
154
189
|
case 8:
|
|
155
|
-
mul_mat_vec_q<8,
|
|
156
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
190
|
+
mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
|
157
191
|
break;
|
|
158
192
|
default:
|
|
159
193
|
GGML_ASSERT(false);
|
|
@@ -165,152 +199,133 @@ static void mul_mat_vec_q4_0_q8_1_cuda(
|
|
|
165
199
|
const void * vx, const void * vy, float * dst,
|
|
166
200
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
167
201
|
|
|
168
|
-
mul_mat_vec_q_cuda<
|
|
169
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
202
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
170
203
|
}
|
|
171
204
|
|
|
172
205
|
static void mul_mat_vec_q4_1_q8_1_cuda(
|
|
173
206
|
const void * vx, const void * vy, float * dst,
|
|
174
207
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
175
208
|
|
|
176
|
-
mul_mat_vec_q_cuda<
|
|
177
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
209
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
178
210
|
}
|
|
179
211
|
|
|
180
212
|
static void mul_mat_vec_q5_0_q8_1_cuda(
|
|
181
213
|
const void * vx, const void * vy, float * dst,
|
|
182
214
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
183
215
|
|
|
184
|
-
mul_mat_vec_q_cuda<
|
|
185
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
216
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
186
217
|
}
|
|
187
218
|
|
|
188
219
|
static void mul_mat_vec_q5_1_q8_1_cuda(
|
|
189
220
|
const void * vx, const void * vy, float * dst,
|
|
190
221
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
191
222
|
|
|
192
|
-
mul_mat_vec_q_cuda<
|
|
193
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
223
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
194
224
|
}
|
|
195
225
|
|
|
196
226
|
static void mul_mat_vec_q8_0_q8_1_cuda(
|
|
197
227
|
const void * vx, const void * vy, float * dst,
|
|
198
228
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
199
229
|
|
|
200
|
-
mul_mat_vec_q_cuda<
|
|
201
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
230
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
202
231
|
}
|
|
203
232
|
|
|
204
233
|
static void mul_mat_vec_q2_K_q8_1_cuda(
|
|
205
234
|
const void * vx, const void * vy, float * dst,
|
|
206
235
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
207
236
|
|
|
208
|
-
mul_mat_vec_q_cuda<
|
|
209
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
237
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
210
238
|
}
|
|
211
239
|
|
|
212
240
|
static void mul_mat_vec_q3_K_q8_1_cuda(
|
|
213
241
|
const void * vx, const void * vy, float * dst,
|
|
214
242
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
215
243
|
|
|
216
|
-
mul_mat_vec_q_cuda<
|
|
217
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
244
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
218
245
|
}
|
|
219
246
|
|
|
220
247
|
static void mul_mat_vec_q4_K_q8_1_cuda(
|
|
221
248
|
const void * vx, const void * vy, float * dst,
|
|
222
249
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
223
250
|
|
|
224
|
-
mul_mat_vec_q_cuda<
|
|
225
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
251
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
226
252
|
}
|
|
227
253
|
|
|
228
254
|
static void mul_mat_vec_q5_K_q8_1_cuda(
|
|
229
255
|
const void * vx, const void * vy, float * dst,
|
|
230
256
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
231
257
|
|
|
232
|
-
mul_mat_vec_q_cuda<
|
|
233
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
258
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
234
259
|
}
|
|
235
260
|
|
|
236
261
|
static void mul_mat_vec_q6_K_q8_1_cuda(
|
|
237
262
|
const void * vx, const void * vy, float * dst,
|
|
238
263
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
239
264
|
|
|
240
|
-
mul_mat_vec_q_cuda<
|
|
241
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
265
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
242
266
|
}
|
|
243
267
|
|
|
244
268
|
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
|
|
245
269
|
const void * vx, const void * vy, float * dst,
|
|
246
270
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
247
271
|
|
|
248
|
-
mul_mat_vec_q_cuda<
|
|
249
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
272
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
250
273
|
}
|
|
251
274
|
|
|
252
275
|
static void mul_mat_vec_iq2_xs_q8_1_cuda(
|
|
253
276
|
const void * vx, const void * vy, float * dst,
|
|
254
277
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
255
278
|
|
|
256
|
-
mul_mat_vec_q_cuda<
|
|
257
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
279
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
258
280
|
}
|
|
259
281
|
|
|
260
282
|
static void mul_mat_vec_iq2_s_q8_1_cuda(
|
|
261
283
|
const void * vx, const void * vy, float * dst,
|
|
262
284
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
263
285
|
|
|
264
|
-
mul_mat_vec_q_cuda<
|
|
265
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
286
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
266
287
|
}
|
|
267
288
|
|
|
268
289
|
static void mul_mat_vec_iq3_xxs_q8_1_cuda(
|
|
269
290
|
const void * vx, const void * vy, float * dst,
|
|
270
291
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
271
292
|
|
|
272
|
-
mul_mat_vec_q_cuda<
|
|
273
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
293
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
274
294
|
}
|
|
275
295
|
|
|
276
296
|
static void mul_mat_vec_iq1_s_q8_1_cuda(
|
|
277
297
|
const void * vx, const void * vy, float * dst,
|
|
278
298
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
279
299
|
|
|
280
|
-
mul_mat_vec_q_cuda<
|
|
281
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
300
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
282
301
|
}
|
|
283
302
|
|
|
284
303
|
static void mul_mat_vec_iq1_m_q8_1_cuda(
|
|
285
304
|
const void * vx, const void * vy, float * dst,
|
|
286
305
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
287
306
|
|
|
288
|
-
mul_mat_vec_q_cuda<
|
|
289
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
307
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
290
308
|
}
|
|
291
309
|
|
|
292
310
|
static void mul_mat_vec_iq4_nl_q8_1_cuda(
|
|
293
311
|
const void * vx, const void * vy, float * dst,
|
|
294
312
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
295
313
|
|
|
296
|
-
mul_mat_vec_q_cuda<
|
|
297
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
314
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
298
315
|
}
|
|
299
316
|
|
|
300
317
|
static void mul_mat_vec_iq4_xs_q8_1_cuda(
|
|
301
318
|
const void * vx, const void * vy, float * dst,
|
|
302
319
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
303
320
|
|
|
304
|
-
mul_mat_vec_q_cuda<
|
|
305
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
321
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
306
322
|
}
|
|
307
323
|
|
|
308
324
|
static void mul_mat_vec_iq3_s_q8_1_cuda(
|
|
309
325
|
const void * vx, const void * vy, float * dst,
|
|
310
326
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
|
311
327
|
|
|
312
|
-
mul_mat_vec_q_cuda<
|
|
313
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
328
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
|
314
329
|
}
|
|
315
330
|
|
|
316
331
|
void ggml_cuda_op_mul_mat_vec_q(
|
|
@@ -1,22 +1,23 @@
|
|
|
1
1
|
#include "quantize.cuh"
|
|
2
|
+
#include <cstdint>
|
|
2
3
|
|
|
3
|
-
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t
|
|
4
|
-
const int64_t
|
|
4
|
+
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
|
|
5
|
+
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
|
5
6
|
|
|
6
|
-
if (
|
|
7
|
+
if (ix0 >= kx0_padded) {
|
|
7
8
|
return;
|
|
8
9
|
}
|
|
9
10
|
|
|
10
|
-
const int64_t
|
|
11
|
+
const int64_t ix1 = blockIdx.y;
|
|
11
12
|
|
|
12
|
-
const int64_t i_padded =
|
|
13
|
+
const int64_t i_padded = ix1*kx0_padded + ix0;
|
|
13
14
|
|
|
14
15
|
block_q8_1 * y = (block_q8_1 *) vy;
|
|
15
16
|
|
|
16
17
|
const int64_t ib = i_padded / QK8_1; // block index
|
|
17
18
|
const int64_t iqs = i_padded % QK8_1; // quant index
|
|
18
19
|
|
|
19
|
-
const float xi =
|
|
20
|
+
const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
|
|
20
21
|
float amax = fabsf(xi);
|
|
21
22
|
float sum = xi;
|
|
22
23
|
|
|
@@ -36,10 +37,76 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
|
|
|
36
37
|
reinterpret_cast<half&>(y[ib].ds.y) = sum;
|
|
37
38
|
}
|
|
38
39
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
const
|
|
40
|
+
template <bool need_sum>
|
|
41
|
+
static __global__ void quantize_mmq_q8_1(
|
|
42
|
+
const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
|
|
43
|
+
|
|
44
|
+
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
|
45
|
+
|
|
46
|
+
if (ix0 >= kx0_padded) {
|
|
47
|
+
return;
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
|
|
51
|
+
|
|
52
|
+
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
|
53
|
+
|
|
54
|
+
const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
|
|
55
|
+
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
|
|
56
|
+
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
|
|
57
|
+
|
|
58
|
+
const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
|
|
59
|
+
float amax = fabsf(xi);
|
|
60
|
+
|
|
61
|
+
amax = warp_reduce_max(amax);
|
|
62
|
+
|
|
63
|
+
float sum;
|
|
64
|
+
if (need_sum) {
|
|
65
|
+
sum = warp_reduce_sum(xi);
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
const float d = amax / 127;
|
|
69
|
+
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
|
70
|
+
|
|
71
|
+
y[ib].qs[iqs] = q;
|
|
72
|
+
|
|
73
|
+
if (iqs % QK8_1 != 0) {
|
|
74
|
+
return;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
if (need_sum) {
|
|
78
|
+
y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
|
|
79
|
+
} else {
|
|
80
|
+
((float *) y[ib].ds)[iqs/QK8_1] = d;
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
void quantize_row_q8_1_cuda(
|
|
85
|
+
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
|
|
86
|
+
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
|
|
87
|
+
|
|
88
|
+
GGML_ASSERT(kx0_padded % QK8_1 == 0);
|
|
89
|
+
|
|
90
|
+
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
|
91
|
+
const dim3 num_blocks(block_num_x, kx1*channels, 1);
|
|
42
92
|
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
|
|
43
|
-
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy,
|
|
93
|
+
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
|
|
94
|
+
|
|
95
|
+
GGML_UNUSED(type_x);
|
|
44
96
|
}
|
|
45
97
|
|
|
98
|
+
void quantize_mmq_q8_1_cuda(
|
|
99
|
+
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
|
|
100
|
+
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
|
|
101
|
+
|
|
102
|
+
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
|
|
103
|
+
|
|
104
|
+
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
|
105
|
+
const dim3 num_blocks(block_num_x, kx1, channels);
|
|
106
|
+
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
|
|
107
|
+
if (mmq_need_sum(type_x)) {
|
|
108
|
+
quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
|
109
|
+
} else {
|
|
110
|
+
quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
|
111
|
+
}
|
|
112
|
+
}
|
|
@@ -130,6 +130,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
|
|
|
130
130
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
131
131
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
132
132
|
|
|
133
|
+
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
|
|
133
134
|
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
|
134
135
|
switch (ncols_x) {
|
|
135
136
|
case 32:
|