llama_cpp 0.16.2 → 0.17.0

Diff between the two gem releases. The headline change is that 0.17.0 stops vendoring the llama.cpp sources: every file under data/vendor/tmp/llama.cpp is deleted (note the +0 counts below) and ext/llama_cpp/extconf.rb drops most of its build logic (+2 -43).
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/README.md +7 -12
- data/ext/llama_cpp/extconf.rb +2 -43
- data/ext/llama_cpp/llama_cpp.cpp +8 -0
- data/lib/llama_cpp/version.rb +3 -3
- data/sig/llama_cpp.rbs +3 -0
- metadata +2 -171
- data/vendor/include/.gitkeep +0 -0
- data/vendor/lib/.gitkeep +0 -0
- data/vendor/tmp/llama.cpp/LICENSE +0 -21
- data/vendor/tmp/llama.cpp/Makefile +0 -1124
- data/vendor/tmp/llama.cpp/ggml-alloc.c +0 -1041
- data/vendor/tmp/llama.cpp/ggml-alloc.h +0 -76
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +0 -153
- data/vendor/tmp/llama.cpp/ggml-backend.c +0 -2225
- data/vendor/tmp/llama.cpp/ggml-backend.h +0 -236
- data/vendor/tmp/llama.cpp/ggml-blas.cpp +0 -363
- data/vendor/tmp/llama.cpp/ggml-blas.h +0 -23
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -1805
- data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +0 -47
- data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +0 -34
- data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +0 -104
- data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +0 -280
- data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +0 -34
- data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +0 -196
- data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +0 -686
- data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +0 -490
- data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +0 -40
- data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +0 -674
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +0 -319
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +0 -312
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +0 -345
- data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +0 -178
- data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +0 -104
- data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +0 -88
- data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +0 -419
- data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +0 -221
- data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +0 -49
- data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +0 -94
- data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +0 -112
- data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +0 -271
- data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +0 -31
- data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +0 -206
- data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +0 -40
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +0 -10
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +0 -9
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +0 -10
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +0 -10
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +0 -8
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q2_k.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q3_k.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q4_k.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_1.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q5_k.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q6_k.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/mmq-instance-q8_0.cu +0 -5
- data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +0 -47
- data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +0 -314
- data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +0 -51
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +0 -3069
- data/vendor/tmp/llama.cpp/ggml-cuda.h +0 -44
- data/vendor/tmp/llama.cpp/ggml-impl.h +0 -651
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -2038
- data/vendor/tmp/llama.cpp/ggml-kompute.h +0 -46
- data/vendor/tmp/llama.cpp/ggml-metal.h +0 -66
- data/vendor/tmp/llama.cpp/ggml-metal.m +0 -3273
- data/vendor/tmp/llama.cpp/ggml-metal.metal +0 -6540
- data/vendor/tmp/llama.cpp/ggml-quants.c +0 -14994
- data/vendor/tmp/llama.cpp/ggml-quants.h +0 -133
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +0 -1178
- data/vendor/tmp/llama.cpp/ggml-rpc.h +0 -24
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +0 -6351
- data/vendor/tmp/llama.cpp/ggml-sycl.h +0 -40
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +0 -144508
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +0 -7183
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -29
- data/vendor/tmp/llama.cpp/ggml.c +0 -22506
- data/vendor/tmp/llama.cpp/ggml.h +0 -2458
- data/vendor/tmp/llama.cpp/llama.cpp +0 -18985
- data/vendor/tmp/llama.cpp/llama.h +0 -1147
- data/vendor/tmp/llama.cpp/scripts/get-flags.mk +0 -38
- data/vendor/tmp/llama.cpp/sgemm.cpp +0 -1032
- data/vendor/tmp/llama.cpp/sgemm.h +0 -14
- data/vendor/tmp/llama.cpp/unicode-data.cpp +0 -7033
- data/vendor/tmp/llama.cpp/unicode-data.h +0 -20
- data/vendor/tmp/llama.cpp/unicode.cpp +0 -810
- data/vendor/tmp/llama.cpp/unicode.h +0 -63
data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu (deleted)
@@ -1,271 +0,0 @@
-#include "rope.cuh"
-
-struct rope_corr_dims {
-    float v[2];
-};
-
-static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / max(0.001f, high - low);
-    return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static __device__ void rope_yarn(
-    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta) {
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
-    }
-    *cos_theta = cosf(theta) * mscale;
-    *sin_theta = sinf(theta) * mscale;
-}
-
-template<typename T, bool has_ff>
-static __global__ void rope_norm(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
-    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
-
-    if (i0 >= ne0) {
-        return;
-    }
-
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
-    const int i = row*ne0 + i0;
-    const int i2 = row/p_delta_rows;
-
-    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
-
-    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
-
-    float cos_theta;
-    float sin_theta;
-
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + 1];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + 1] = x0*sin_theta + x1*cos_theta;
-}
-
-template<typename T, bool has_ff>
-static __global__ void rope_neox(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
-    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
-
-    if (i0 >= ne0) {
-        return;
-    }
-
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
-
-        dst[i + 0] = x[i + 0];
-        dst[i + 1] = x[i + 1];
-
-        return;
-    }
-
-    const int i = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows;
-
-    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
-
-    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
-
-    float cos_theta;
-    float sin_theta;
-
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
-}
-
-template<typename T>
-static void rope_norm_cuda(
-    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
-    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nr, n_blocks_x, 1);
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    if (freq_factors == nullptr) {
-        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
-    } else {
-        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
-    }
-}
-
-template<typename T>
-static void rope_neox_cuda(
-    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
-    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(nr, n_blocks_x, 1);
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    if (freq_factors == nullptr) {
-        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
-    } else {
-        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
-    }
-}
-
-static void rope_norm_cuda_f16(
-    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_norm_cuda_f32(
-    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_neox_cuda_f16(
-    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_neox_cuda_f32(
-    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
-) {
-
-    rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * src2 = dst->src[2];
-
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
-    GGML_ASSERT(src0->type == dst->type);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t nr = ggml_nrows(src0);
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-    // RoPE alteration for extended context
-    float freq_base;
-    float freq_scale;
-    float ext_factor;
-    float attn_factor;
-    float beta_fast;
-    float beta_slow;
-
-    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
-    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
-    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-
-    const bool is_neox = mode & 2;
-
-    const int32_t * pos = (const int32_t *) src1_d;
-
-    const float * freq_factors = nullptr;
-    if (src2 != nullptr) {
-        freq_factors = (const float *) src2->data;
-    }
-
-    rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
-
-    // compute
-    if (is_neox) {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
-        } else {
-            GGML_ASSERT(false);
-        }
-    } else {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_norm_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_norm_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
-        } else {
-            GGML_ASSERT(false);
-        }
-    }
-}
data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu (deleted)
@@ -1,31 +0,0 @@
-#include "scale.cuh"
-
-static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= k) {
-        return;
-    }
-
-    dst[i] = scale * x[i];
-}
-
-static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
-    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
-}
-
-void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    float scale;
-    memcpy(&scale, dst->op_params, sizeof(float));
-
-    scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
-}
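The launch helper above sizes the grid with the standard ceiling-division idiom, (k + B - 1) / B, so the grid covers all k elements even when k is not a multiple of the block size. A tiny C++ sketch of that arithmetic; the block size 256 is a stand-in for CUDA_SCALE_BLOCK_SIZE, whose actual value lived in the removed headers.

#include <cstdio>
#include <initializer_list>

int main() {
    const int block_size = 256; // stand-in for CUDA_SCALE_BLOCK_SIZE
    for (int k : {1, 255, 256, 257, 1000}) {
        // integer ceiling division: smallest block count whose total thread
        // count is >= k; the kernel's "if (i >= k) return;" trims the excess
        const int num_blocks = (k + block_size - 1) / block_size;
        std::printf("k=%4d -> %d block(s), %d threads\n", k, num_blocks, num_blocks * block_size);
    }
    return 0;
}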
data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu (deleted)
@@ -1,206 +0,0 @@
-#include "common.cuh"
-#include "softmax.cuh"
-
-template <typename T>
-static __device__ __forceinline__ float t2f32(T val) {
-    return (float) val;
-}
-
-template <>
-__device__ float __forceinline__ t2f32<half>(half val) {
-    return __half2float(val);
-}
-
-template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
-
-    const int tid = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
-
-    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
-
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
-
-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
-
-    extern __shared__ float data_soft_max_f32[];
-    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
-    // shared memory buffer to cache values between iterations:
-    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
-
-    float max_val = -INFINITY;
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-
-        if (ncols_template == 0 && col >= ncols) {
-            break;
-        }
-
-        const int64_t ix = (int64_t)rowx*ncols + col;
-        const int64_t iy = (int64_t)rowy*ncols + col;
-
-        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
-
-        vals[col] = val;
-        max_val = max(max_val, val);
-    }
-
-    // find the max value in the block
-    max_val = warp_reduce_max(max_val);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf_iw[lane_id] = -INFINITY;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = max_val;
-        }
-        __syncthreads();
-
-        max_val = buf_iw[lane_id];
-        max_val = warp_reduce_max(max_val);
-    }
-
-    float tmp = 0.0f; // partial sum
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-
-        if (ncols_template == 0 && col >= ncols) {
-            break;
-        }
-
-        const float val = expf(vals[col] - max_val);
-        tmp += val;
-        vals[col] = val;
-    }
-
-    // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __syncthreads();
-        if (warp_id == 0) {
-            buf_iw[lane_id] = 0.0f;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = tmp;
-        }
-        __syncthreads();
-
-        tmp = buf_iw[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
-
-    const float inv_sum = 1.0f / tmp;
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-
-        if (ncols_template == 0 && col >= ncols) {
-            return;
-        }
-
-        const int64_t idst = (int64_t)rowx*ncols + col;
-        dst[idst] = vals[col] * inv_sum;
-    }
-}
-
-template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
-    int nth = WARP_SIZE;
-    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
-    const dim3 block_dims(nth, 1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
-    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
-    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
-
-    const uint32_t n_head = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-        }
-    } else {
-        const size_t shmem_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-    }
-}
-
-void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    const float * src0_d = (const float *)src0->data;
-    const void * src1_d = src1 ? (const void *)src1->data : nullptr;
-
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
-
-    float scale = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
-
-    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
-
-    if (use_f16) {
-        const half * src1_dd = (const half *)src1_d;
-
-        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
-    } else {
-        const float * src1_dd = (const float *)src1_d;
-
-        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
-    }
-}
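Stripped of the warp/block reduction machinery, the removed soft_max_f32 computes, per row, softmax(x*scale + slope*mask) with the usual max-subtraction trick for numerical stability. Here is a host-side C++ reference of that per-row math, offered as a sketch rather than the ggml implementation:

#include <cmath>
#include <algorithm>
#include <vector>
#include <cstdio>

// Host reference for one row of soft_max_f32: scaled logits plus an optional
// ALiBi-sloped mask, then a numerically stable three-pass softmax.
static void soft_max_row(const float * x, const float * mask, float * dst,
                         int ncols, float scale, float slope) {
    float max_val = -INFINITY;
    std::vector<float> vals(ncols);
    for (int c = 0; c < ncols; ++c) { // pass 1: scaled logits + mask, track the max
        vals[c] = x[c]*scale + (mask ? slope*mask[c] : 0.0f);
        max_val = std::max(max_val, vals[c]);
    }
    float sum = 0.0f;
    for (int c = 0; c < ncols; ++c) { // pass 2: exponentiate (shifted by max) and accumulate
        vals[c] = std::exp(vals[c] - max_val);
        sum += vals[c];
    }
    const float inv_sum = 1.0f / sum;
    for (int c = 0; c < ncols; ++c) { // pass 3: normalize
        dst[c] = vals[c] * inv_sum;
    }
}

int main() {
    const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float y[4];
    soft_max_row(x, nullptr, y, 4, 1.0f, 0.0f);
    for (float v : y) std::printf("%f ", v); // probabilities summing to 1
    std::printf("\n");
    return 0;
}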
data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu (deleted)
@@ -1,40 +0,0 @@
-#include "sumrows.cuh"
-
-static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col == 0) {
-        dst[row] = sum;
-    }
-}
-
-static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(nrows, 1, 1);
-    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
-}
-
-void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-
-    const int64_t ncols = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
-    sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
-}
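The removed k_sum_rows_f32 assigns one warp per row: each lane accumulates a strided slice of the row and warp_reduce_sum then combines the 32 partial sums. A host-side C++ sketch of that reduction shape, where the pairwise-halving loop stands in for the shuffle-based warp reduction:

#include <cstdio>

int main() {
    const int warp_size = 32;
    const int ncols = 100;
    float row[100];
    for (int i = 0; i < ncols; ++i) row[i] = 1.0f; // row sums to 100

    // each "lane" accumulates a strided slice of the row, as in the kernel
    float partial[32] = {0.0f};
    for (int lane = 0; lane < warp_size; ++lane) {
        for (int i = lane; i < ncols; i += warp_size) {
            partial[lane] += row[i];
        }
    }
    // pairwise tree reduction, mirroring what a shuffle-based
    // warp_reduce_sum does across lanes in log2(32) steps
    for (int offset = warp_size/2; offset > 0; offset /= 2) {
        for (int lane = 0; lane < offset; ++lane) {
            partial[lane] += partial[lane + offset];
        }
    }
    std::printf("row sum = %f\n", partial[0]); // prints 100.000000
    return 0;
}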