llama_cpp 0.15.4 → 0.16.0

Files changed (147)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +15 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +13 -1
  7. data/vendor/tmp/llama.cpp/Makefile +62 -35
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +8 -6
  131. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -6
  132. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  133. data/vendor/tmp/llama.cpp/ggml-metal.m +34 -24
  134. data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
  135. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +2 -2
  136. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +7 -67
  137. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  138. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +456 -329
  139. data/vendor/tmp/llama.cpp/ggml.c +178 -330
  140. data/vendor/tmp/llama.cpp/ggml.h +9 -28
  141. data/vendor/tmp/llama.cpp/llama.cpp +242 -426
  142. data/vendor/tmp/llama.cpp/llama.h +17 -43
  143. metadata +121 -6
  144. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  145. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  146. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  147. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu
@@ -0,0 +1,490 @@
+ #include "cpy.cuh"
+
+ typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
+
+ static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+     const float * xi = (const float *) cxi;
+     float * dsti = (float *) cdsti;
+
+     *dsti = *xi;
+ }
+
+ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+     const float * xi = (const float *) cxi;
+     half * dsti = (half *) cdsti;
+
+     *dsti = __float2half(*xi);
+ }
+
+ static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+     const half * xi = (const half *) cxi;
+     half * dsti = (half *) cdsti;
+
+     *dsti = *xi;
+ }
+
+ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+     const half * xi = (const half *) cxi;
+     float * dsti = (float *) cdsti;
+
+     *dsti = *xi;
+ }
+
+ template <cpy_kernel_t cpy_1>
+ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+     const int nb12, const int nb13) {
+     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+     if (i >= ne) {
+         return;
+     }
+
+     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+     // then combine those indices with the corresponding byte offsets to get the total offsets
+     const int64_t i03 = i/(ne00 * ne01 * ne02);
+     const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+     const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+     const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+     const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+     const int64_t i13 = i/(ne10 * ne11 * ne12);
+     const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+     const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+     const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+     const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+
+     cpy_1(cx + x_offset, cdst + dst_offset);
+ }
+
+ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+     const float * xi = (const float *) cxi;
+     block_q8_0 * dsti = (block_q8_0 *) cdsti;
+
+     float amax = 0.0f; // absolute max
+
+     for (int j = 0; j < QK8_0; j++) {
+         const float v = xi[j];
+         amax = fmaxf(amax, fabsf(v));
+     }
+
+     const float d = amax / ((1 << 7) - 1);
+     const float id = d ? 1.0f/d : 0.0f;
+
+     dsti->d = d;
+
+     for (int j = 0; j < QK8_0; ++j) {
+         const float x0 = xi[j]*id;
+
+         dsti->qs[j] = roundf(x0);
+     }
+ }
+
+ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+     const float * xi = (const float *) cxi;
+     block_q4_0 * dsti = (block_q4_0 *) cdsti;
+
+     float amax = 0.0f;
+     float vmax = 0.0f;
+
+     for (int j = 0; j < QK4_0; ++j) {
+         const float v = xi[j];
+         if (amax < fabsf(v)) {
+             amax = fabsf(v);
+             vmax = v;
+         }
+     }
+
+     const float d = vmax / -8;
+     const float id = d ? 1.0f/d : 0.0f;
+
+     dsti->d = d;
+
+     for (int j = 0; j < QK4_0/2; ++j) {
+         const float x0 = xi[0 + j]*id;
+         const float x1 = xi[QK4_0/2 + j]*id;
+
+         const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+         const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+         dsti->qs[j] = xi0;
+         dsti->qs[j] |= xi1 << 4;
+     }
+ }
+
+ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+     const float * xi = (const float *) cxi;
+     block_q4_1 * dsti = (block_q4_1 *) cdsti;
+
+     float vmin = FLT_MAX;
+     float vmax = -FLT_MAX;
+
+     for (int j = 0; j < QK4_1; ++j) {
+         const float v = xi[j];
+
+         if (v < vmin) vmin = v;
+         if (v > vmax) vmax = v;
+     }
+
+     const float d = (vmax - vmin) / ((1 << 4) - 1);
+     const float id = d ? 1.0f/d : 0.0f;
+
+     dsti->dm.x = d;
+     dsti->dm.y = vmin;
+
+     for (int j = 0; j < QK4_1/2; ++j) {
+         const float x0 = (xi[0 + j] - vmin)*id;
+         const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
+
+         const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+         const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+         dsti->qs[j] = xi0;
+         dsti->qs[j] |= xi1 << 4;
+     }
+ }
+
+ static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
+     const float * xi = (const float *) cxi;
+     block_q5_0 * dsti = (block_q5_0 *) cdsti;
+
+     float amax = 0.0f;
+     float vmax = 0.0f;
+
+     for (int j = 0; j < QK5_0; ++j) {
+         const float v = xi[j];
+         if (amax < fabsf(v)) {
+             amax = fabsf(v);
+             vmax = v;
+         }
+     }
+
+     const float d = vmax / -16;
+     const float id = d ? 1.0f/d : 0.0f;
+
+     dsti->d = d;
+
+     uint32_t qh = 0;
+     for (int j = 0; j < QK5_0/2; ++j) {
+         const float x0 = xi[0 + j]*id;
+         const float x1 = xi[QK5_0/2 + j]*id;
+
+         const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
+         const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
+
+         dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+         qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+         qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+     }
+     memcpy(dsti->qh, &qh, sizeof(qh));
+ }
+
+ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
+     const float * xi = (const float *) cxi;
+     block_q5_1 * dsti = (block_q5_1 *) cdsti;
+
+     float min = xi[0];
+     float max = xi[0];
+
+     for (int j = 1; j < QK5_1; ++j) {
+         const float v = xi[j];
+         min = v < min ? v : min;
+         max = v > max ? v : max;
+     }
+
+     const float d = (max - min) / 31;
+     const float id = d ? 1.0f/d : 0.0f;
+
+     dsti->dm.x = d;
+     dsti->dm.y = min;
+
+     uint32_t qh = 0;
+     for (int j = 0; j < QK5_1/2; ++j) {
+         const float x0 = (xi[0 + j] - min)*id;
+         const float x1 = (xi[QK5_1/2 + j] - min)*id;
+
+         const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+         const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+         dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+         qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+         qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+     }
+     memcpy(dsti->qh, &qh, sizeof(qh));
+ }
+
+
+ static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
+     if (x <= val[0]) return 0;
+     if (x >= val[n-1]) return n-1;
+     int ml = 0, mu = n-1;
+     while (mu-ml > 1) {
+         int mav = (ml+mu)/2;
+         if (x < val[mav]) mu = mav; else ml = mav;
+     }
+     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+ }
+
+ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
+     const float * xi = (const float *) cxi;
+     block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
+
+     float amax = 0.0f;
+     float vmax = 0.0f;
+
+     for (int j = 0; j < QK4_NL; ++j) {
+         const float v = xi[j];
+         if (amax < fabsf(v)) {
+             amax = fabsf(v);
+             vmax = v;
+         }
+     }
+
+     float d = vmax / kvalues_iq4nl[0];
+     const float id = d ? 1.0f/d : 0.0f;
+
+     float sumqx = 0, sumq2 = 0;
+     for (int j = 0; j < QK4_NL/2; ++j) {
+         const float x0 = xi[0 + j]*id;
+         const float x1 = xi[QK4_NL/2 + j]*id;
+         const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
+         const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
+         dsti->qs[j] = xi0 | (xi1 << 4);
+         const float v0 = kvalues_iq4nl[xi0];
+         const float v1 = kvalues_iq4nl[xi1];
+         const float w0 = xi[0 + j]*xi[0 + j];
+         const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
+         sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
+         sumq2 += w0*v0*v0 + w1*v1*v1;
+     }
+
+     dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
+ }
+
+ template <cpy_kernel_t cpy_blck, int qk>
+ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+     const int nb12, const int nb13) {
+     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+     if (i >= ne) {
+         return;
+     }
+
+     const int i03 = i/(ne00 * ne01 * ne02);
+     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+     const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+     const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+     const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+     const int i13 = i/(ne10 * ne11 * ne12);
+     const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+     const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+     const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+     const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+     cpy_blck(cx + x_offset, cdst + dst_offset);
+ }
+
+ static void ggml_cpy_f16_f32_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+     cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+
+ static void ggml_cpy_f32_f32_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+     cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+
+ static void ggml_cpy_f32_f16_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+     cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+
+ static void ggml_cpy_f32_q8_0_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+     GGML_ASSERT(ne % QK8_0 == 0);
+     const int num_blocks = ne / QK8_0;
+     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+
+ static void ggml_cpy_f32_q4_0_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+     GGML_ASSERT(ne % QK4_0 == 0);
+     const int num_blocks = ne / QK4_0;
+     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+
+ static void ggml_cpy_f32_q4_1_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+     GGML_ASSERT(ne % QK4_1 == 0);
+     const int num_blocks = ne / QK4_1;
+     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+
+ static void ggml_cpy_f32_q5_0_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+     GGML_ASSERT(ne % QK5_0 == 0);
+     const int num_blocks = ne / QK5_0;
+     cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
+         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+
+ static void ggml_cpy_f32_q5_1_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+     GGML_ASSERT(ne % QK5_1 == 0);
+     const int num_blocks = ne / QK5_1;
+     cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
+         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+
+ static void ggml_cpy_f32_iq4_nl_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+     GGML_ASSERT(ne % QK4_NL == 0);
+     const int num_blocks = ne / QK4_NL;
+     cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
+         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+
+ static void ggml_cpy_f16_f16_cuda(
+     const char * cx, char * cdst, const int ne,
+     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+     cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+
+ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
+     const int64_t ne = ggml_nelements(src0);
+     GGML_ASSERT(ne == ggml_nelements(src1));
+
+     GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+     GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+
+     const int64_t ne00 = src0->ne[0];
+     const int64_t ne01 = src0->ne[1];
+     const int64_t ne02 = src0->ne[2];
+
+     //GGML_ASSERT(src0->ne[3] == 1);
+
+     const int64_t nb00 = src0->nb[0];
+     const int64_t nb01 = src0->nb[1];
+     const int64_t nb02 = src0->nb[2];
+     const int64_t nb03 = src0->nb[3];
+
+     const int64_t ne10 = src1->ne[0];
+     const int64_t ne11 = src1->ne[1];
+     const int64_t ne12 = src1->ne[2];
+
+     //GGML_ASSERT(src1->ne[3] == 1);
+
+     const int64_t nb10 = src1->nb[0];
+     const int64_t nb11 = src1->nb[1];
+     const int64_t nb12 = src1->nb[2];
+     const int64_t nb13 = src1->nb[3];
+
+     cudaStream_t main_stream = ctx.stream();
+
+     char * src0_ddc = (char *) src0->data;
+     char * src1_ddc = (char *) src1->data;
+
+     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+         ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+         ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+         ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+         ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+     } else {
+         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+             ggml_type_name(src0->type), ggml_type_name(src1->type));
+         GGML_ASSERT(false);
+     }
+ }
+
+ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     const ggml_tensor * src0 = dst->src[0];
+     ggml_cuda_cpy(ctx, src0, dst);
+ }
+
+ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+         return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+         return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+         return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+         return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+         return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+     } else {
+         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+             ggml_type_name(src0->type), ggml_type_name(src1->type));
+         GGML_ASSERT(false);
+     }
+ }
+
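For reference, the cpy_f32_f16/cpy_f32_q kernels above simply decompose a flattened element index into four tensor coordinates and turn those into byte offsets via the nb strides. Below is a standalone host-side sketch of that same index arithmetic; it is not part of the gem, and the shapes and strides are made up for illustration.

    // Standalone sketch (not from the diff): the index decomposition used by
    // cpy_f32_f16/cpy_f32_q, run on the host with hypothetical shapes/strides.
    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical contiguous f32 tensor with ne = {8, 4, 2, 2}; nb holds byte strides
        const int64_t ne00 = 8, ne01 = 4, ne02 = 2;
        const int64_t nb00 = 4, nb01 = nb00*ne00, nb02 = nb01*ne01, nb03 = nb02*ne02;

        const int64_t i = 57; // flattened element index (the kernel derives this from the thread/block index)

        // same decomposition as in the kernels above
        const int64_t i03 = i/(ne00*ne01*ne02);
        const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
        const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
        const int64_t i00 =  i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
        const int64_t offset = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;

        printf("element %lld -> (i00=%lld, i01=%lld, i02=%lld, i03=%lld), byte offset %lld\n",
               (long long)i, (long long)i00, (long long)i01, (long long)i02, (long long)i03, (long long)offset);
        return 0;
    }

The quantized path (cpy_f32_q) uses the same decomposition, except that each thread handles qk consecutive source elements and the destination offset is computed per block, hence the (i10/qk)*nb10 term.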
data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu
@@ -0,0 +1,40 @@
+ #include "diagmask.cuh"
+
+ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+     const int col = blockDim.y*blockIdx.y + threadIdx.y;
+     const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+     if (col >= ncols) {
+         return;
+     }
+
+     const int i = row*ncols + col;
+     //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+     //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
+ }
+
+ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
+     const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
+     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
+     const dim3 block_nums(nrows_x, block_num_x, 1);
+     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
+ }
+
+ void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     const ggml_tensor * src0 = dst->src[0];
+     const float * src0_d = (const float *)src0->data;
+     float * dst_d = (float *)dst->data;
+     cudaStream_t stream = ctx.stream();
+
+     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+     GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+     const int64_t ne00 = src0->ne[0];
+     const int64_t ne01 = src0->ne[1];
+     const int nrows0 = ggml_nrows(src0);
+
+     const int n_past = ((int32_t *) dst->op_params)[0];
+
+     diag_mask_inf_f32_cuda(src0_d, dst_d, ne00, nrows0, ne01, n_past, stream);
+ }
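To see what diag_mask_inf_f32 actually produces, here is a small standalone CUDA sketch, not part of the gem, that applies the same masking rule to a 4x4 matrix of ones with n_past = 0: entries above the diagonal end up near -FLT_MAX, which acts like -inf in a subsequent softmax.

    // Standalone sketch (not from the diff): same masking rule as diag_mask_inf_f32,
    // applied to a 4x4 matrix of ones and printed on the host.
    #include <cfloat>
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void mask_demo(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
        const int col = blockDim.y*blockIdx.y + threadIdx.y;
        const int row = blockDim.x*blockIdx.x + threadIdx.x;
        if (col >= ncols) return;
        const int i = row*ncols + col;
        // columns past n_past + row (within the channel) are pushed down by FLT_MAX
        dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
    }

    int main() {
        const int rows = 4, cols = 4, n_past = 0;
        float h_x[rows*cols], h_y[rows*cols];
        for (int i = 0; i < rows*cols; ++i) h_x[i] = 1.0f;

        float *d_x, *d_y;
        cudaMalloc((void **) &d_x, sizeof(h_x));
        cudaMalloc((void **) &d_y, sizeof(h_y));
        cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);

        // mirrors the launch shape used by diag_mask_inf_f32_cuda: one row per block in x,
        // columns tiled along y with a block size of 32
        const dim3 block(1, 32, 1);
        const dim3 grid(rows, (cols + 31)/32, 1);
        mask_demo<<<grid, block>>>(d_x, d_y, cols, rows, n_past);
        cudaMemcpy(h_y, d_y, sizeof(h_y), cudaMemcpyDeviceToHost);

        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) printf("%10.1e ", h_y[r*cols + c]);
            printf("\n");
        }
        cudaFree(d_x);
        cudaFree(d_y);
        return 0;
    }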