llama_cpp 0.15.4 → 0.16.0

Files changed (147)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +15 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +13 -1
  7. data/vendor/tmp/llama.cpp/Makefile +62 -35
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +8 -6
  131. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -6
  132. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  133. data/vendor/tmp/llama.cpp/ggml-metal.m +34 -24
  134. data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
  135. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +2 -2
  136. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +7 -67
  137. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  138. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +456 -329
  139. data/vendor/tmp/llama.cpp/ggml.c +178 -330
  140. data/vendor/tmp/llama.cpp/ggml.h +9 -28
  141. data/vendor/tmp/llama.cpp/llama.cpp +242 -426
  142. data/vendor/tmp/llama.cpp/llama.h +17 -43
  143. metadata +121 -6
  144. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  145. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  146. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  147. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu (new file)
@@ -0,0 +1,686 @@
+ #include "convert.cuh"
+ #include "dequantize.cuh"
+
+ #define CUDA_Q8_0_NE_ALIGN 2048
+
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+     const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
+
+     if (i >= k) {
+         return;
+     }
+
+     const int64_t ib = i/qk; // block index
+     const int64_t iqs = (i%qk)/qr; // quant index
+     const int64_t iybs = i - i%qk; // y block start index
+     const int64_t y_offset = qr == 1 ? 1 : qk/2;
+
+     // dequantize
+     dfloat2 v;
+     dequantize_kernel(vx, ib, iqs, v);
+
+     y[iybs + iqs + 0] = v.x;
+     y[iybs + iqs + y_offset] = v.y;
+ }
+
+ template <bool need_check>
+ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
+ #if __CUDA_ARCH__ >= CC_PASCAL
+     constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
+
+     const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
+     const int * x0 = ((int *) vx) + blockIdx.x * nint;
+     half2 * y2 = (half2 *) (y + i0);
+
+     __shared__ int vals[nint];
+
+ #pragma unroll
+     for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
+         if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
+             break;
+         }
+
+         const int ix = ix0 + threadIdx.x;
+         vals[ix] = x0[ix];
+     }
+
+     __syncthreads();
+
+ #pragma unroll
+     for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
+         if (need_check && i0 + iy + 2*threadIdx.x >= k) {
+             return;
+         }
+
+         const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
+         const half d = *b0;
+         const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
+
+         y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
+     }
+ #else
+     GGML_UNUSED(vx);
+     GGML_UNUSED(y);
+     GGML_UNUSED(k);
+     NO_DEVICE_CODE;
+ #endif // __CUDA_ARCH__ >= CC_PASCAL
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+     const int64_t i = blockIdx.x;
+
+     // assume 32 threads
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8;
+     const int64_t ir = tid%8;
+     const int64_t ib = 8*i + ir;
+     if (ib >= nb32) {
+         return;
+     }
+
+     dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+     const block_q4_0 * x = (const block_q4_0 *)vx + ib;
+     const float d = __half2float(x->d);
+     const float dm = -8*d;
+
+     const uint8_t * q = x->qs + 4*il;
+
+     for (int l = 0; l < 4; ++l) {
+         y[l+ 0] = d * (q[l] & 0xF) + dm;
+         y[l+16] = d * (q[l] >> 4) + dm;
+     }
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+     const int64_t i = blockIdx.x;
+
+     // assume 32 threads
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8;
+     const int64_t ir = tid%8;
+     const int64_t ib = 8*i + ir;
+     if (ib >= nb32) {
+         return;
+     }
+
+     dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+     const block_q4_1 * x = (const block_q4_1 *)vx + ib;
+     const float2 d = __half22float2(x->dm);
+
+     const uint8_t * q = x->qs + 4*il;
+
+     for (int l = 0; l < 4; ++l) {
+         y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
+         y[l+16] = d.x * (q[l] >> 4) + d.y;
+     }
+ }
+
+ //================================== k-quants
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+     const int64_t i = blockIdx.x;
+     const block_q2_K * x = (const block_q2_K *) vx;
+
+     const int64_t tid = threadIdx.x;
+     const int64_t n = tid/32;
+     const int64_t l = tid - 32*n;
+     const int64_t is = 8*n + l/16;
+
+     const uint8_t q = x[i].qs[32*n + l];
+     dst_t * y = yy + i*QK_K + 128*n;
+
+     float dall = __low2half(x[i].dm);
+     float dmin = __high2half(x[i].dm);
+     y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
+     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
+     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+     const int64_t i = blockIdx.x;
+     const block_q3_K * x = (const block_q3_K *) vx;
+
+     const int64_t r = threadIdx.x/4;
+     const int64_t tid = r/2;
+     const int64_t is0 = r%2;
+     const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
+     const int64_t n = tid / 4;
+     const int64_t j = tid - 4*n;
+
+     uint8_t m = 1 << (4*n + j);
+     int64_t is = 8*n + 2*j + is0;
+     int shift = 2*j;
+
+     int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+               is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+               is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+                         (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+     float d_all = x[i].d;
+     float dl = d_all * (us - 32);
+
+     dst_t * y = yy + i*QK_K + 128*n + 32*j;
+     const uint8_t * q = x[i].qs + 32*n;
+     const uint8_t * hm = x[i].hmask;
+
+     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+ }
+
+ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+     if (j < 4) {
+         d = q[j] & 63; m = q[j + 4] & 63;
+     } else {
+         d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+         m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
+     }
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+     const block_q4_K * x = (const block_q4_K *) vx;
+
+     const int64_t i = blockIdx.x;
+
+     // assume 32 threads
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8;
+     const int64_t ir = tid%8;
+     const int64_t is = 2*il;
+     const int64_t n = 4;
+
+     dst_t * y = yy + i*QK_K + 64*il + n*ir;
+
+     const float dall = __low2half(x[i].dm);
+     const float dmin = __high2half(x[i].dm);
+
+     const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+     uint8_t sc, m;
+     get_scale_min_k4(is + 0, x[i].scales, sc, m);
+     const float d1 = dall * sc; const float m1 = dmin * m;
+     get_scale_min_k4(is + 1, x[i].scales, sc, m);
+     const float d2 = dall * sc; const float m2 = dmin * m;
+     for (int l = 0; l < n; ++l) {
+         y[l + 0] = d1 * (q[l] & 0xF) - m1;
+         y[l +32] = d2 * (q[l] >> 4) - m2;
+     }
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+     const block_q5_K * x = (const block_q5_K *) vx;
+
+     const int64_t i = blockIdx.x;
+
+     // assume 64 threads - this is very slightly better than the one below
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/16; // il is in 0...3
+     const int64_t ir = tid%16; // ir is in 0...15
+     const int64_t is = 2*il; // is is in 0...6
+
+     dst_t * y = yy + i*QK_K + 64*il + 2*ir;
+
+     const float dall = __low2half(x[i].dm);
+     const float dmin = __high2half(x[i].dm);
+
+     const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+     const uint8_t * qh = x[i].qh + 2*ir;
+
+     uint8_t sc, m;
+     get_scale_min_k4(is + 0, x[i].scales, sc, m);
+     const float d1 = dall * sc; const float m1 = dmin * m;
+     get_scale_min_k4(is + 1, x[i].scales, sc, m);
+     const float d2 = dall * sc; const float m2 = dmin * m;
+
+     uint8_t hm = 1 << (2*il);
+     y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+     y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+     hm <<= 1;
+     y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+     y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+     const block_q6_K * x = (const block_q6_K *) vx;
+
+     const int64_t i = blockIdx.x;
+
+     // assume 64 threads - this is very slightly better than the one below
+     const int64_t tid = threadIdx.x;
+     const int64_t ip = tid/32; // ip is 0 or 1
+     const int64_t il = tid - 32*ip; // 0...32
+     const int64_t is = 8*ip + il/16;
+
+     dst_t * y = yy + i*QK_K + 128*ip + il;
+
+     const float d = x[i].d;
+
+     const uint8_t * ql = x[i].ql + 64*ip + il;
+     const uint8_t qh = x[i].qh[32*ip + il];
+     const int8_t * sc = x[i].scales + is;
+
+     y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+     y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+     y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+     y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+     const int64_t i = blockIdx.x;
+     const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
+
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8; // 0...3
+     const int64_t ib = tid%8; // 0...7
+     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+     const uint16_t * q2 = x[i].qs + 4*ib;
+     const uint8_t * aux8 = (const uint8_t *)q2;
+     const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
+     const uint32_t aux32 = q2[2] | (q2[3] << 16);
+     const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+     const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+     for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+     const int64_t i = blockIdx.x;
+     const block_iq2_xs * x = (const block_iq2_xs *) vx;
+
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8; // 0...3
+     const int64_t ib = tid%8; // 0...7
+     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+     const uint16_t * q2 = x[i].qs + 4*ib;
+     const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
+     const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+     const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
+     for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+     const int64_t i = blockIdx.x;
+     const block_iq2_s * x = (const block_iq2_s *) vx;
+
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8; // 0...3
+     const int64_t ib = tid%8; // 0...7
+     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+     const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+     const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+     const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+     for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+     const int64_t i = blockIdx.x;
+     const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
+
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8; // 0...3
+     const int64_t ib = tid%8; // 0...7
+     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+     const uint8_t * q3 = x[i].qs + 8*ib;
+     const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
+     const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
+     const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
+     const uint32_t aux32 = gas[0] | (gas[1] << 16);
+     const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
+     const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+     for (int j = 0; j < 4; ++j) {
+         y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+         y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+     }
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+     const int64_t i = blockIdx.x;
+     const block_iq3_s * x = (const block_iq3_s *) vx;
+
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8; // 0...3
+     const int64_t ib = tid%8; // 0...7
+     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+     const uint8_t * qs = x[i].qs + 8*ib;
+     const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+     const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
+     const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
+     const uint8_t signs = x[i].signs[4*ib + il];
+     for (int j = 0; j < 4; ++j) {
+         y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+         y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+     }
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+     const int64_t i = blockIdx.x;
+     const block_iq1_s * x = (const block_iq1_s *) vx;
+
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8; // 0...3
+     const int64_t ib = tid%8; // 0...7
+     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+     const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
+     const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
+     uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+     grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
+     grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+     grid32[0] &= 0x0f0f0f0f;
+     for (int j = 0; j < 8; ++j) {
+         y[j] = d * (q[j] + delta);
+     }
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+     const int64_t i = blockIdx.x;
+     const block_iq1_m * x = (const block_iq1_m *) vx;
+
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8; // 0...3
+     const int64_t ib = tid%8; // 0...7
+     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+     const uint16_t * sc = (const uint16_t *)x[i].scales;
+     iq1m_scale_t scale;
+     scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+     const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+     const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+     const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+     uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+     grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+     grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+     grid32[0] &= 0x0f0f0f0f;
+     for (int j = 0; j < 8; ++j) {
+         y[j] = d * (q[j] + delta);
+     }
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+     const int64_t i = blockIdx.x;
+     const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8; // 0...3
+     const int64_t ib = tid%8; // 0...7
+     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+     const uint8_t * q4 = x[ib].qs + 4*il;
+     const float d = (float)x[ib].d;
+     for (int j = 0; j < 4; ++j) {
+         y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+         y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+     }
+ }
+
+ template<typename dst_t>
+ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+     const int64_t i = blockIdx.x;
+     const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+     const int64_t tid = threadIdx.x;
+     const int64_t il = tid/8; // 0...3
+     const int64_t ib = tid%8; // 0...7
+     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+     const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
+     const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+     for (int j = 0; j < 4; ++j) {
+         y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+         y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+     }
+ }
+
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+ static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+     const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
+     dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+ }
+
+ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
+     const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
+     if (k % CUDA_Q8_0_NE_ALIGN == 0) {
+         const bool need_check = false;
+         dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
+     } else {
+         const bool need_check = true;
+         dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
+     }
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb32 = k / 32;
+     const int nb = (k + 255) / 256;
+     dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb32 = k / 32;
+     const int nb = (k + 255) / 256;
+     dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = (k + QK_K - 1) / QK_K;
+     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = k / QK_K;
+     dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
+ }
+
+ template<typename dst_t>
+ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+     const int nb = (k + QK_K - 1) / QK_K;
+     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
+ }
+
+ template <typename src_t, typename dst_t>
+ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+     const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+     if (i >= k) {
+         return;
+     }
+
+     const src_t * x = (src_t *) vx;
+
+     y[i] = x[i];
+ }
+
+ template <typename src_t, typename dst_t>
+ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+     convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+ }
+
+ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+     switch (type) {
+         case GGML_TYPE_Q4_0:
+             return dequantize_row_q4_0_cuda;
+         case GGML_TYPE_Q4_1:
+             return dequantize_row_q4_1_cuda;
+         case GGML_TYPE_Q5_0:
+             return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+         case GGML_TYPE_Q5_1:
+             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+         case GGML_TYPE_Q8_0:
+             if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
+                 return dequantize_block_q8_0_f16_cuda;
+             }
+             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+         case GGML_TYPE_Q2_K:
+             return dequantize_row_q2_K_cuda;
+         case GGML_TYPE_Q3_K:
+             return dequantize_row_q3_K_cuda;
+         case GGML_TYPE_Q4_K:
+             return dequantize_row_q4_K_cuda;
+         case GGML_TYPE_Q5_K:
+             return dequantize_row_q5_K_cuda;
+         case GGML_TYPE_Q6_K:
+             return dequantize_row_q6_K_cuda;
+         case GGML_TYPE_IQ2_XXS:
+             return dequantize_row_iq2_xxs_cuda;
+         case GGML_TYPE_IQ2_XS:
+             return dequantize_row_iq2_xs_cuda;
+         case GGML_TYPE_IQ2_S:
+             return dequantize_row_iq2_s_cuda;
+         case GGML_TYPE_IQ3_XXS:
+             return dequantize_row_iq3_xxs_cuda;
+         case GGML_TYPE_IQ1_S:
+             return dequantize_row_iq1_s_cuda;
+         case GGML_TYPE_IQ1_M:
+             return dequantize_row_iq1_m_cuda;
+         case GGML_TYPE_IQ4_NL:
+             return dequantize_row_iq4_nl_cuda;
+         case GGML_TYPE_IQ4_XS:
+             return dequantize_row_iq4_xs_cuda;
+         case GGML_TYPE_IQ3_S:
+             return dequantize_row_iq3_s_cuda;
+         case GGML_TYPE_F32:
+             return convert_unary_cuda<float>;
+         default:
+             return nullptr;
+     }
+ }
+
+ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
+     switch (type) {
+         case GGML_TYPE_Q4_0:
+             return dequantize_row_q4_0_cuda;
+         case GGML_TYPE_Q4_1:
+             return dequantize_row_q4_1_cuda;
+         case GGML_TYPE_Q5_0:
+             return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+         case GGML_TYPE_Q5_1:
+             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+         case GGML_TYPE_Q8_0:
+             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+         case GGML_TYPE_Q2_K:
+             return dequantize_row_q2_K_cuda;
+         case GGML_TYPE_Q3_K:
+             return dequantize_row_q3_K_cuda;
+         case GGML_TYPE_Q4_K:
+             return dequantize_row_q4_K_cuda;
+         case GGML_TYPE_Q5_K:
+             return dequantize_row_q5_K_cuda;
+         case GGML_TYPE_Q6_K:
+             return dequantize_row_q6_K_cuda;
+         case GGML_TYPE_IQ2_XXS:
+             return dequantize_row_iq2_xxs_cuda;
+         case GGML_TYPE_IQ2_XS:
+             return dequantize_row_iq2_xs_cuda;
+         case GGML_TYPE_IQ2_S:
+             return dequantize_row_iq2_s_cuda;
+         case GGML_TYPE_IQ3_XXS:
+             return dequantize_row_iq3_xxs_cuda;
+         case GGML_TYPE_IQ1_S:
+             return dequantize_row_iq1_s_cuda;
+         case GGML_TYPE_IQ1_M:
+             return dequantize_row_iq1_m_cuda;
+         case GGML_TYPE_IQ4_NL:
+             return dequantize_row_iq4_nl_cuda;
+         case GGML_TYPE_IQ4_XS:
+             return dequantize_row_iq4_xs_cuda;
+         case GGML_TYPE_IQ3_S:
+             return dequantize_row_iq3_s_cuda;
+         case GGML_TYPE_F16:
+             return convert_unary_cuda<half>;
+         default:
+             return nullptr;
+     }
+ }
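
The new convert.cu exposes its dequantization kernels through the `ggml_get_to_fp16_cuda`/`ggml_get_to_fp32_cuda` dispatch functions shown above, which return a launcher for the requested tensor type (or `nullptr` if none exists). The following is a minimal sketch, not part of the diff, of how a caller could use that dispatch; the wrapper name and the assumption of a device buffer holding `k` Q4_K-quantized values are illustrative only.

```cuda
// Hypothetical caller, for illustration: dequantize k Q4_K values to f16 on a stream.
// Assumes convert.cuh (from this diff) and the ggml CUDA common headers are available,
// and that d_src/d_dst are valid device pointers.
#include "convert.cuh"

static void dequantize_q4_K_to_f16(const void * d_src, half * d_dst, int64_t k, cudaStream_t stream) {
    const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(GGML_TYPE_Q4_K);
    if (to_fp16 == nullptr) {
        return; // no f16 conversion kernel registered for this type
    }
    // Launches dequantize_block_q4_K: one 32-thread block per QK_K-sized superblock.
    to_fp16(d_src, d_dst, k, stream);
}
```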