llama_cpp 0.15.4 → 0.16.0

Files changed (147)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +15 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +13 -1
  7. data/vendor/tmp/llama.cpp/Makefile +62 -35
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +8 -6
  131. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -6
  132. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  133. data/vendor/tmp/llama.cpp/ggml-metal.m +34 -24
  134. data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
  135. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +2 -2
  136. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +7 -67
  137. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  138. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +456 -329
  139. data/vendor/tmp/llama.cpp/ggml.c +178 -330
  140. data/vendor/tmp/llama.cpp/ggml.h +9 -28
  141. data/vendor/tmp/llama.cpp/llama.cpp +242 -426
  142. data/vendor/tmp/llama.cpp/llama.h +17 -43
  143. metadata +121 -6
  144. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  145. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  146. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  147. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu
@@ -0,0 +1,94 @@
+#include "pool2d.cuh"
+
+template <typename Ti, typename To>
+static __global__ void pool2d_nchw_kernel(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const Ti* src, To* dst, const enum ggml_op_pool op) {
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx >= parallel_elements) {
+        return;
+    }
+
+    const int I_HW = ih * iw;
+    const int O_HW = oh * ow;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / ow;
+    const int cur_ow = idx % O_HW % ow;
+    const Ti* i_ptr = src + nc * I_HW;
+    To* o_ptr = dst + nc * O_HW;
+    const int start_h = cur_oh * sh - ph;
+    const int bh = max(0, start_h);
+    const int eh = min(ih, start_h + kh);
+    const int start_w = cur_ow * sw - pw;
+    const int bw = max(0, start_w);
+    const int ew = min(iw, start_w + kw);
+    const To scale = 1. / (kh * kw);
+    To res = 0;
+
+    switch (op) {
+        case GGML_OP_POOL_AVG: res = 0; break;
+        case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+        default: assert(false);
+    }
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+#if __CUDA_ARCH__ >= 350
+            Ti cur = __ldg(i_ptr + i * iw + j);
+#else
+            Ti cur = i_ptr[i * iw + j];
+#endif
+            switch (op) {
+                case GGML_OP_POOL_AVG: res += cur * scale; break;
+                case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
+                default: assert(false);
+            }
+        }
+    }
+    o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
+static void pool2d_nchw_kernel_f32_f32_cuda(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const float * src, float * dst, const enum ggml_op_pool op,
+        cudaStream_t stream) {
+
+    const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
+    dim3 block_nums(num_blocks);
+    pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, stream>>>(ih, iw, oh, ow, kh, kw, sh, sw, ph, pw, parallel_elements, src, dst, op);
+}
+
+void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    const int64_t IH = src0->ne[1];
+    const int64_t IW = src0->ne[0];
+
+    const int64_t N = dst->ne[3];
+    const int64_t OC = dst->ne[2];
+    const int64_t OH = dst->ne[1];
+    const int64_t OW = dst->ne[0];
+
+    const int parallel_elements = N * OC * OH * OW;
+
+    pool2d_nchw_kernel_f32_f32_cuda(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_d, dst_d, op, stream);
+}
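For orientation: ggml_cuda_op_pool2d launches one thread per output element, and the kernel recovers the fused batch*channel index and the output row/column from the flat thread index with division and modulo. A minimal host-side C++ sketch of that decomposition, with illustrative sizes (not part of the gem):

#include <cstdio>

// Mirrors the index math in pool2d_nchw_kernel: one (virtual) thread per output element.
int main() {
    const int OH = 3, OW = 4;                    // example output spatial size
    const int O_HW = OH * OW;
    for (int idx = 0; idx < 2 * O_HW; ++idx) {   // two channels' worth of elements
        const int nc     = idx / O_HW;           // fused batch*channel index
        const int cur_oh = idx % O_HW / OW;      // output row
        const int cur_ow = idx % O_HW % OW;      // output column
        std::printf("idx=%2d -> nc=%d oh=%d ow=%d\n", idx, nc, cur_oh, cur_ow);
    }
    return 0;
}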
data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu
@@ -0,0 +1,45 @@
+#include "quantize.cuh"
+
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) {
+    const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (ix >= kx_padded) {
+        return;
+    }
+
+    const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y;
+
+    const int64_t i_padded = (int64_t)iy*kx_padded + ix;
+
+    block_q8_1 * y = (block_q8_1 *) vy;
+
+    const int64_t ib = i_padded / QK8_1; // block index
+    const int64_t iqs = i_padded % QK8_1; // quant index
+
+    const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
+    float amax = fabsf(xi);
+    float sum = xi;
+
+    amax = warp_reduce_max(amax);
+    sum = warp_reduce_sum(sum);
+
+    const float d = amax / 127;
+    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+    y[ib].qs[iqs] = q;
+
+    if (iqs > 0) {
+        return;
+    }
+
+    reinterpret_cast<half&>(y[ib].ds.x) = d;
+    reinterpret_cast<half&>(y[ib].ds.y) = sum;
+}
+
+void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) {
+    const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, ky, 1);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+}
+
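The kernel quantizes each group of QK8_1 = 32 consecutive (zero-padded) values independently: the scale is d = max|x| / 127, each value is rounded to x_i / d, and the lowest lane additionally stores d and the sum of the original values in the block header (as half precision in the CUDA path). A scalar C++ reference of the per-block math, for illustration only; BlockQ8_1Ref is a hypothetical stand-in for block_q8_1, not part of the gem:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Scalar reference of what quantize_q8_1 computes for a single 32-element block.
struct BlockQ8_1Ref {          // illustrative stand-in; the real block stores d and s as half
    float  d;                  // scale: max|x| / 127
    float  s;                  // sum of the original float values in the block
    int8_t qs[32];             // quantized values
};

BlockQ8_1Ref quantize_block_ref(const std::vector<float> & x) {
    assert(x.size() == 32);
    float amax = 0.0f, sum = 0.0f;
    for (float v : x) { amax = std::max(amax, std::fabs(v)); sum += v; }
    BlockQ8_1Ref b{amax / 127.0f, sum, {}};
    for (size_t i = 0; i < x.size(); ++i) {
        b.qs[i] = amax == 0.0f ? 0 : (int8_t) std::round(x[i] / b.d);
    }
    return b;
}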
data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu
@@ -0,0 +1,271 @@
+#include "rope.cuh"
+
+struct rope_corr_dims {
+    float v[2];
+};
+
+static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static __device__ void rope_yarn(
+    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_norm(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i = row*ne0 + i0;
+    const int i2 = row/p_delta_rows;
+
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + 1];
+
+    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_neox(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i = row*ne0 + i0/2;
+    const int i2 = row/p_delta_rows;
+
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims/2];
+
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T>
+static void rope_norm_cuda(
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
+    } else {
+        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
+    }
+}
+
+template<typename T>
+static void rope_neox_cuda(
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
+    } else {
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+            theta_scale, freq_factors
+        );
+    }
+}
+
+static void rope_norm_cuda_f16(
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+    rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_norm_cuda_f32(
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {

+    rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_neox_cuda_f16(
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+    rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_neox_cuda_f32(
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
+) {
+
+    rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t nr = ggml_nrows(src0);
+
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    // RoPE alteration for extended context
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+
+    const bool is_neox = mode & 2;
+
+    const int32_t * pos = (const int32_t *) src1_d;
+
+    const float * freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
+    }
+
+    rope_corr_dims corr_dims;
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
+
+    // compute
+    if (is_neox) {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_neox_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_neox_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else {
+            GGML_ASSERT(false);
+        }
+    } else {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_norm_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_norm_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else {
+            GGML_ASSERT(false);
+        }
+    }
+}
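For orientation, in the default ("norm") path each adjacent pair (x_{2j}, x_{2j+1}) of a row is rotated by an angle built from the token position p = pos[i2], freq_base and the optional per-dimension frequency factor f_j (1.0 when src2 is absent). Ignoring the YaRN correction (ext_factor = 0), the kernel computes:

$$\theta_j = \texttt{freq\_scale}\cdot\frac{p\cdot \texttt{freq\_base}^{-2j/n_\text{dims}}}{f_j},\qquad
\begin{pmatrix} y_{2j} \\ y_{2j+1} \end{pmatrix} = \texttt{attn\_factor}\cdot
\begin{pmatrix} \cos\theta_j & -\sin\theta_j \\ \sin\theta_j & \cos\theta_j \end{pmatrix}
\begin{pmatrix} x_{2j} \\ x_{2j+1} \end{pmatrix}$$

When ext_factor is non-zero, rope_yarn blends the interpolated and extrapolated angles with the ramp function and rescales the magnitude; the "neox" variant applies the same rotation to the pair (x_j, x_{j + n_dims/2}) instead of adjacent elements.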
data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu
@@ -0,0 +1,31 @@
+#include "scale.cuh"
+
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = scale * x[i];
+}
+
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+}
+
+void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float scale;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
+}
data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu
@@ -0,0 +1,205 @@
+#include "common.cuh"
+#include "softmax.cuh"
+
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+    return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+    return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+
+    extern __shared__ float data_soft_max_f32[];
+    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+    // shared memory buffer to cache values between iterations:
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
+
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const int64_t ix = (int64_t)rowx*ncols + col;
+        const int64_t iy = (int64_t)rowy*ncols + col;
+
+        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+
+        vals[col] = val;
+        max_val = max(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf_iw[lane_id];
+        max_val = warp_reduce_max(max_val);
+    }
+
+    float tmp = 0.0f; // partial sum
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = expf(vals[col] - max_val);
+        tmp += val;
+        vals[col] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __syncthreads();
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf_iw[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int64_t idst = (int64_t)rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
+    }
+}
+
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+    const uint32_t n_head      = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+    }
+}
+
+void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    const float * src0_d = (const float *)src0->data;
+    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+    const int64_t ne00    = src0->ne[0];
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src0->ne[1];
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+    if (use_f16) {
+        const half * src1_dd = (const half *)src1_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    } else {
+        const float * src1_dd = (const float *)src1_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    }
+}
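In effect the kernel computes a numerically stable, row-wise softmax of the scaled logits plus an optional ALiBi-biased mask. With slope_r derived from max_bias, m0, m1 and n_head_log2 (slope_r = 1 when max_bias is 0), and the mask term dropped when no mask tensor is given:

$$y_{r,c} = \frac{\exp\!\big(v_{r,c}-\max_{c'} v_{r,c'}\big)}{\sum_{c'}\exp\!\big(v_{r,c'}-\max_{c''} v_{r,c''}\big)},\qquad
v_{r,c} = \texttt{scale}\cdot x_{r,c} + \text{slope}_r\cdot m_{\,r \bmod n_{\text{rows},y},\,c}$$

Rows are cached in shared memory when they fit (the templated ncols cases); otherwise the dst buffer itself is used as scratch for the intermediate exponentials.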
data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu
@@ -0,0 +1,40 @@
+#include "sumrows.cuh"
+
+static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col == 0) {
+        dst[row] = sum;
+    }
+}
+
+static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
+void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+}
data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate-variants.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);