llama_cpp 0.16.0 → 0.16.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/extconf.rb +3 -0
- data/ext/llama_cpp/llama_cpp.cpp +14 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +119 -54
- data/vendor/tmp/llama.cpp/ggml-alloc.c +78 -22
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +20 -8
- data/vendor/tmp/llama.cpp/ggml-backend.c +190 -65
- data/vendor/tmp/llama.cpp/ggml-backend.h +6 -3
- data/vendor/tmp/llama.cpp/ggml-blas.cpp +363 -0
- data/vendor/tmp/llama.cpp/ggml-blas.h +23 -0
- data/vendor/tmp/llama.cpp/ggml-common.h +6 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +1 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +21 -9
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +15 -1491
- data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +77 -62
- data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +77 -10
- data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +1 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +48 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +95 -129
- data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +8 -7
- data/vendor/tmp/llama.cpp/ggml-metal.m +17 -9
- data/vendor/tmp/llama.cpp/ggml-quants.c +982 -368
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +21 -15
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +2133 -13215
- data/vendor/tmp/llama.cpp/ggml-sycl.h +1 -10
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +28826 -25037
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +438 -493
- data/vendor/tmp/llama.cpp/ggml.c +158 -414
- data/vendor/tmp/llama.cpp/ggml.h +6 -0
- data/vendor/tmp/llama.cpp/llama.cpp +628 -279
- data/vendor/tmp/llama.cpp/llama.h +9 -1
- data/vendor/tmp/llama.cpp/sgemm.cpp +2 -0
- data/vendor/tmp/llama.cpp/unicode-data.cpp +851 -801
- data/vendor/tmp/llama.cpp/unicode.cpp +33 -19
- data/vendor/tmp/llama.cpp/unicode.h +1 -1
- metadata +15 -3
@@ -1,9 +1,47 @@
|
|
1
1
|
#include "mmvq.cuh"
|
2
2
|
#include "vecdotq.cuh"
|
3
3
|
|
4
|
-
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
|
4
|
+
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
|
5
|
+
|
6
|
+
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
|
7
|
+
return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
|
8
|
+
type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
|
9
|
+
type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
|
10
|
+
type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
|
11
|
+
type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
|
12
|
+
type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
|
13
|
+
type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
|
14
|
+
type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
|
15
|
+
type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
|
16
|
+
type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
|
17
|
+
type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
|
18
|
+
type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
|
19
|
+
type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
|
20
|
+
type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
|
21
|
+
type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
|
22
|
+
type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
|
23
|
+
type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
|
24
|
+
type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
|
25
|
+
type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
|
26
|
+
nullptr;
|
27
|
+
}
|
28
|
+
|
29
|
+
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
30
|
+
return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
|
31
|
+
type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
|
32
|
+
type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
|
33
|
+
type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
|
34
|
+
type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
|
35
|
+
type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
|
36
|
+
type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
|
37
|
+
type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
|
38
|
+
type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
|
39
|
+
type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
|
40
|
+
type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ :
|
41
|
+
1;
|
42
|
+
}
|
5
43
|
|
6
|
-
template <
|
44
|
+
template <ggml_type type, int ncols_y>
|
7
45
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
8
46
|
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
9
47
|
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
@@ -12,6 +50,12 @@ static __global__ void mul_mat_vec_q(
|
|
12
50
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
13
51
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
14
52
|
|
53
|
+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
54
|
+
constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
55
|
+
constexpr int vdr = get_vdr_mmvq(type);
|
56
|
+
|
57
|
+
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
58
|
+
|
15
59
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
16
60
|
constexpr int nwarps = 1;
|
17
61
|
constexpr int rows_per_cuda_block = 1;
|
@@ -29,7 +73,6 @@ static __global__ void mul_mat_vec_q(
|
|
29
73
|
// partial sum for each thread
|
30
74
|
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
|
31
75
|
|
32
|
-
const block_q_t * x = (const block_q_t *) vx;
|
33
76
|
const block_q8_1 * y = (const block_q8_1 *) vy;
|
34
77
|
|
35
78
|
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
|
@@ -42,8 +85,7 @@ static __global__ void mul_mat_vec_q(
|
|
42
85
|
for (int j = 0; j < ncols_y; ++j) {
|
43
86
|
#pragma unroll
|
44
87
|
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
45
|
-
tmp[j][i] += vec_dot_q_cuda(
|
46
|
-
&x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
|
88
|
+
tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
|
47
89
|
}
|
48
90
|
}
|
49
91
|
}
|
@@ -75,18 +117,18 @@ static __global__ void mul_mat_vec_q(
|
|
75
117
|
tmp[j][i] = warp_reduce_sum(tmp[j][i]);
|
76
118
|
}
|
77
119
|
|
78
|
-
if (threadIdx.x < rows_per_cuda_block) {
|
120
|
+
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
|
79
121
|
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
|
80
122
|
}
|
81
123
|
}
|
82
124
|
}
|
83
125
|
|
84
|
-
template <
|
126
|
+
template <ggml_type type>
|
85
127
|
static void mul_mat_vec_q_cuda(
|
86
128
|
const void * vx, const void * vy, float * dst,
|
87
129
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
88
130
|
|
89
|
-
GGML_ASSERT(ncols_x %
|
131
|
+
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
|
90
132
|
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
|
91
133
|
|
92
134
|
int id = ggml_cuda_get_device();
|
@@ -124,36 +166,28 @@ static void mul_mat_vec_q_cuda(
|
|
124
166
|
|
125
167
|
switch (ncols_y) {
|
126
168
|
case 1:
|
127
|
-
mul_mat_vec_q<1,
|
128
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
169
|
+
mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
129
170
|
break;
|
130
171
|
case 2:
|
131
|
-
mul_mat_vec_q<2,
|
132
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
172
|
+
mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
133
173
|
break;
|
134
174
|
case 3:
|
135
|
-
mul_mat_vec_q<3,
|
136
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
175
|
+
mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
137
176
|
break;
|
138
177
|
case 4:
|
139
|
-
mul_mat_vec_q<4,
|
140
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
178
|
+
mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
141
179
|
break;
|
142
180
|
case 5:
|
143
|
-
mul_mat_vec_q<5,
|
144
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
181
|
+
mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
145
182
|
break;
|
146
183
|
case 6:
|
147
|
-
mul_mat_vec_q<6,
|
148
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
184
|
+
mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
149
185
|
break;
|
150
186
|
case 7:
|
151
|
-
mul_mat_vec_q<7,
|
152
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
187
|
+
mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
153
188
|
break;
|
154
189
|
case 8:
|
155
|
-
mul_mat_vec_q<8,
|
156
|
-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
190
|
+
mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
157
191
|
break;
|
158
192
|
default:
|
159
193
|
GGML_ASSERT(false);
|
@@ -165,152 +199,133 @@ static void mul_mat_vec_q4_0_q8_1_cuda(
|
|
165
199
|
const void * vx, const void * vy, float * dst,
|
166
200
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
167
201
|
|
168
|
-
mul_mat_vec_q_cuda<
|
169
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
202
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
170
203
|
}
|
171
204
|
|
172
205
|
static void mul_mat_vec_q4_1_q8_1_cuda(
|
173
206
|
const void * vx, const void * vy, float * dst,
|
174
207
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
175
208
|
|
176
|
-
mul_mat_vec_q_cuda<
|
177
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
209
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
178
210
|
}
|
179
211
|
|
180
212
|
static void mul_mat_vec_q5_0_q8_1_cuda(
|
181
213
|
const void * vx, const void * vy, float * dst,
|
182
214
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
183
215
|
|
184
|
-
mul_mat_vec_q_cuda<
|
185
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
216
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
186
217
|
}
|
187
218
|
|
188
219
|
static void mul_mat_vec_q5_1_q8_1_cuda(
|
189
220
|
const void * vx, const void * vy, float * dst,
|
190
221
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
191
222
|
|
192
|
-
mul_mat_vec_q_cuda<
|
193
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
223
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
194
224
|
}
|
195
225
|
|
196
226
|
static void mul_mat_vec_q8_0_q8_1_cuda(
|
197
227
|
const void * vx, const void * vy, float * dst,
|
198
228
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
199
229
|
|
200
|
-
mul_mat_vec_q_cuda<
|
201
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
230
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
202
231
|
}
|
203
232
|
|
204
233
|
static void mul_mat_vec_q2_K_q8_1_cuda(
|
205
234
|
const void * vx, const void * vy, float * dst,
|
206
235
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
207
236
|
|
208
|
-
mul_mat_vec_q_cuda<
|
209
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
237
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
210
238
|
}
|
211
239
|
|
212
240
|
static void mul_mat_vec_q3_K_q8_1_cuda(
|
213
241
|
const void * vx, const void * vy, float * dst,
|
214
242
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
215
243
|
|
216
|
-
mul_mat_vec_q_cuda<
|
217
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
244
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
218
245
|
}
|
219
246
|
|
220
247
|
static void mul_mat_vec_q4_K_q8_1_cuda(
|
221
248
|
const void * vx, const void * vy, float * dst,
|
222
249
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
223
250
|
|
224
|
-
mul_mat_vec_q_cuda<
|
225
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
251
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
226
252
|
}
|
227
253
|
|
228
254
|
static void mul_mat_vec_q5_K_q8_1_cuda(
|
229
255
|
const void * vx, const void * vy, float * dst,
|
230
256
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
231
257
|
|
232
|
-
mul_mat_vec_q_cuda<
|
233
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
258
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
234
259
|
}
|
235
260
|
|
236
261
|
static void mul_mat_vec_q6_K_q8_1_cuda(
|
237
262
|
const void * vx, const void * vy, float * dst,
|
238
263
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
239
264
|
|
240
|
-
mul_mat_vec_q_cuda<
|
241
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
265
|
+
mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
242
266
|
}
|
243
267
|
|
244
268
|
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
|
245
269
|
const void * vx, const void * vy, float * dst,
|
246
270
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
247
271
|
|
248
|
-
mul_mat_vec_q_cuda<
|
249
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
272
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
250
273
|
}
|
251
274
|
|
252
275
|
static void mul_mat_vec_iq2_xs_q8_1_cuda(
|
253
276
|
const void * vx, const void * vy, float * dst,
|
254
277
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
255
278
|
|
256
|
-
mul_mat_vec_q_cuda<
|
257
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
279
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
258
280
|
}
|
259
281
|
|
260
282
|
static void mul_mat_vec_iq2_s_q8_1_cuda(
|
261
283
|
const void * vx, const void * vy, float * dst,
|
262
284
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
263
285
|
|
264
|
-
mul_mat_vec_q_cuda<
|
265
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
286
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
266
287
|
}
|
267
288
|
|
268
289
|
static void mul_mat_vec_iq3_xxs_q8_1_cuda(
|
269
290
|
const void * vx, const void * vy, float * dst,
|
270
291
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
271
292
|
|
272
|
-
mul_mat_vec_q_cuda<
|
273
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
293
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
274
294
|
}
|
275
295
|
|
276
296
|
static void mul_mat_vec_iq1_s_q8_1_cuda(
|
277
297
|
const void * vx, const void * vy, float * dst,
|
278
298
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
279
299
|
|
280
|
-
mul_mat_vec_q_cuda<
|
281
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
300
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
282
301
|
}
|
283
302
|
|
284
303
|
static void mul_mat_vec_iq1_m_q8_1_cuda(
|
285
304
|
const void * vx, const void * vy, float * dst,
|
286
305
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
287
306
|
|
288
|
-
mul_mat_vec_q_cuda<
|
289
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
307
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
290
308
|
}
|
291
309
|
|
292
310
|
static void mul_mat_vec_iq4_nl_q8_1_cuda(
|
293
311
|
const void * vx, const void * vy, float * dst,
|
294
312
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
295
313
|
|
296
|
-
mul_mat_vec_q_cuda<
|
297
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
314
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
298
315
|
}
|
299
316
|
|
300
317
|
static void mul_mat_vec_iq4_xs_q8_1_cuda(
|
301
318
|
const void * vx, const void * vy, float * dst,
|
302
319
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
303
320
|
|
304
|
-
mul_mat_vec_q_cuda<
|
305
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
321
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
306
322
|
}
|
307
323
|
|
308
324
|
static void mul_mat_vec_iq3_s_q8_1_cuda(
|
309
325
|
const void * vx, const void * vy, float * dst,
|
310
326
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
311
327
|
|
312
|
-
mul_mat_vec_q_cuda<
|
313
|
-
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
328
|
+
mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
|
314
329
|
}
|
315
330
|
|
316
331
|
void ggml_cuda_op_mul_mat_vec_q(
|
@@ -1,22 +1,23 @@
|
|
1
1
|
#include "quantize.cuh"
|
2
|
+
#include <cstdint>
|
2
3
|
|
3
|
-
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t
|
4
|
-
const int64_t
|
4
|
+
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
|
5
|
+
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
5
6
|
|
6
|
-
if (
|
7
|
+
if (ix0 >= kx0_padded) {
|
7
8
|
return;
|
8
9
|
}
|
9
10
|
|
10
|
-
const int64_t
|
11
|
+
const int64_t ix1 = blockIdx.y;
|
11
12
|
|
12
|
-
const int64_t i_padded =
|
13
|
+
const int64_t i_padded = ix1*kx0_padded + ix0;
|
13
14
|
|
14
15
|
block_q8_1 * y = (block_q8_1 *) vy;
|
15
16
|
|
16
17
|
const int64_t ib = i_padded / QK8_1; // block index
|
17
18
|
const int64_t iqs = i_padded % QK8_1; // quant index
|
18
19
|
|
19
|
-
const float xi =
|
20
|
+
const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
|
20
21
|
float amax = fabsf(xi);
|
21
22
|
float sum = xi;
|
22
23
|
|
@@ -36,10 +37,76 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
|
|
36
37
|
reinterpret_cast<half&>(y[ib].ds.y) = sum;
|
37
38
|
}
|
38
39
|
|
39
|
-
|
40
|
-
|
41
|
-
const
|
40
|
+
template <bool need_sum>
|
41
|
+
static __global__ void quantize_mmq_q8_1(
|
42
|
+
const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
|
43
|
+
|
44
|
+
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
45
|
+
|
46
|
+
if (ix0 >= kx0_padded) {
|
47
|
+
return;
|
48
|
+
}
|
49
|
+
|
50
|
+
const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
|
51
|
+
|
52
|
+
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
53
|
+
|
54
|
+
const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
|
55
|
+
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
|
56
|
+
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
|
57
|
+
|
58
|
+
const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
|
59
|
+
float amax = fabsf(xi);
|
60
|
+
|
61
|
+
amax = warp_reduce_max(amax);
|
62
|
+
|
63
|
+
float sum;
|
64
|
+
if (need_sum) {
|
65
|
+
sum = warp_reduce_sum(xi);
|
66
|
+
}
|
67
|
+
|
68
|
+
const float d = amax / 127;
|
69
|
+
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
70
|
+
|
71
|
+
y[ib].qs[iqs] = q;
|
72
|
+
|
73
|
+
if (iqs % QK8_1 != 0) {
|
74
|
+
return;
|
75
|
+
}
|
76
|
+
|
77
|
+
if (need_sum) {
|
78
|
+
y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
|
79
|
+
} else {
|
80
|
+
((float *) y[ib].ds)[iqs/QK8_1] = d;
|
81
|
+
}
|
82
|
+
}
|
83
|
+
|
84
|
+
void quantize_row_q8_1_cuda(
|
85
|
+
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
|
86
|
+
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
|
87
|
+
|
88
|
+
GGML_ASSERT(kx0_padded % QK8_1 == 0);
|
89
|
+
|
90
|
+
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
91
|
+
const dim3 num_blocks(block_num_x, kx1*channels, 1);
|
42
92
|
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
|
43
|
-
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy,
|
93
|
+
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
|
94
|
+
|
95
|
+
GGML_UNUSED(type_x);
|
44
96
|
}
|
45
97
|
|
98
|
+
void quantize_mmq_q8_1_cuda(
|
99
|
+
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
|
100
|
+
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
|
101
|
+
|
102
|
+
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
|
103
|
+
|
104
|
+
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
105
|
+
const dim3 num_blocks(block_num_x, kx1, channels);
|
106
|
+
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
|
107
|
+
if (mmq_need_sum(type_x)) {
|
108
|
+
quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
109
|
+
} else {
|
110
|
+
quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
111
|
+
}
|
112
|
+
}
|
@@ -130,6 +130,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
|
|
130
130
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
131
131
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
132
132
|
|
133
|
+
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
|
133
134
|
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
134
135
|
switch (ncols_x) {
|
135
136
|
case 32:
|