llama_cpp 0.15.4 → 0.16.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/ext/llama_cpp/extconf.rb +1 -2
- data/ext/llama_cpp/llama_cpp.cpp +15 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +13 -1
- data/vendor/tmp/llama.cpp/Makefile +62 -35
- data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
- data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +8 -6
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -6
- data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +34 -24
- data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +2 -2
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +7 -67
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +456 -329
- data/vendor/tmp/llama.cpp/ggml.c +178 -330
- data/vendor/tmp/llama.cpp/ggml.h +9 -28
- data/vendor/tmp/llama.cpp/llama.cpp +242 -426
- data/vendor/tmp/llama.cpp/llama.h +17 -43
- metadata +121 -6
- data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
- data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
- data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
ggml-metal.metal:

@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    thread float * cos_theta, thread float * sin_theta
-) {
+    thread float * cos_theta, thread float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -1672,19 +1671,20 @@ static void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(int n_dims, int
-    return n_dims * log(
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
 }
 
 static void rope_yarn_corr_dims(
-    int n_dims, int
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
     // start and end correction dims
-    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims,
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims,
+    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
 }
 
-
+template<typename T>
+kernel void kernel_rope_norm(
         device const void * src0,
         device const int32_t * src1,
         device const float * src2,
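For reference, the YaRN correction math that this hunk carries over to the renamed `n_ctx_orig` parameter can be reproduced host-side. The sketch below is a plain C++ transcription of the two functions shown in the diff, with `std::` math and a local `PI_F` constant standing in for Metal's `M_PI_F`; it is illustrative only, not the gem's or llama.cpp's public API.

```cpp
#include <algorithm>
#include <cmath>

constexpr float PI_F = 3.14159265358979f;

// corr_fac(n_rot) = n_dims * log(n_ctx_orig / (n_rot * 2*pi)) / (2 * log(base))
static float yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    return n_dims * std::log(n_ctx_orig / (n_rot * 2.0f * PI_F)) / (2.0f * std::log(base));
}

// dims[0]..dims[1] is the ramp of rotation dimensions over which YaRN blends
// extrapolated and interpolated angles; beta_fast/beta_slow bound the ramp.
static void yarn_corr_dims(int n_dims, int n_ctx_orig, float freq_base,
                           float beta_fast, float beta_slow, float dims[2]) {
    dims[0] = std::max(0.0f,          std::floor(yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
    dims[1] = std::min(n_dims - 1.0f, std::ceil (yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}
```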
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
         constant uint64_t & nb3,
         constant int & n_past,
         constant int & n_dims,
-        constant int &
-        constant int & n_orig_ctx,
+        constant int & n_ctx_orig,
         constant float & freq_base,
         constant float & freq_scale,
         constant float & ext_factor,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
         constant float & beta_slow,
         uint tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
-        uint3 tgpig[[threadgroup_position_in_grid]])
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
 
 template<typename T>
-kernel void
+kernel void kernel_rope_neox(
         device const void * src0,
         device const int32_t * src1,
         device const float * src2,
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
         constant uint64_t & nb3,
         constant int & n_past,
         constant int & n_dims,
-        constant int &
-        constant int & n_orig_ctx,
+        constant int & n_ctx_orig,
         constant float & freq_base,
         constant float & freq_scale,
         constant float & ext_factor,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
     const int64_t i2 = tgpig[1];
     const int64_t i1 = tgpig[0];
 
-    const bool is_neox = mode & 2;
-
     float corr_dims[2];
-    rope_yarn_corr_dims(n_dims,
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     device const int32_t * pos = src1;
 
-    const
-
-    const float theta_base = (float)p;
+    const float theta_base = (float) pos[i2];
     const float inv_ndims = -1.f/n_dims;
 
-
-
-        const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
-        float cos_theta, sin_theta;
-        rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-        device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-        device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-        const T x0 = src[0];
-        const T x1 = src[1];
+    float cos_theta;
+    float sin_theta;
 
-
-
-
-    } else {
-        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
-            if (ic < n_dims) {
-                const int64_t i0 = ic/2;
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
 
-
-
-                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-
-                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
 
-
-                device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-
-
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
 
-
-
-            } else {
-                const int64_t i0 = ic;
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
 
-
-
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-
-
-            }
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
         }
     }
 }
 
-
-
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
 
 typedef void (im2col_t)(
         device const float * x,
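The net effect of splitting the old `kernel_rope` into `kernel_rope_norm` and `kernel_rope_neox` is that each kernel now handles exactly one pair layout: NORM rotates adjacent elements `(i0, i0 + 1)`, while NEOX rotates elements half a rotation block apart, `(ic, ic + n_dims/2)`. The CPU-side C++ sketch below shows only that pairing difference, ignoring the YaRN correction and per-dimension frequency factors; `rotate_pair`, `rope_norm_row`, and `rope_neox_row` are hypothetical helper names, not functions from llama.cpp.

```cpp
#include <cmath>

// Rotate one (x0, x1) pair by theta.
static void rotate_pair(float & x0, float & x1, float cos_theta, float sin_theta) {
    const float t0 = x0*cos_theta - x1*sin_theta;
    const float t1 = x0*sin_theta + x1*cos_theta;
    x0 = t0;
    x1 = t1;
}

// NORM layout: pairs are adjacent elements (i0, i0 + 1).
static void rope_norm_row(float * row, int n_dims, float theta_base, float freq_base) {
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const float theta = theta_base * std::pow(freq_base, -(float) i0 / n_dims);
        rotate_pair(row[i0], row[i0 + 1], std::cos(theta), std::sin(theta));
    }
}

// NEOX layout: pairs are (ic, ic + n_dims/2), i.e. matching slots in the two halves.
static void rope_neox_row(float * row, int n_dims, float theta_base, float freq_base) {
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const int ic = i0/2;
        const float theta = theta_base * std::pow(freq_base, -(float) i0 / n_dims);
        rotate_pair(row[ic], row[ic + n_dims/2], std::cos(theta), std::sin(theta));
    }
}

int main() {
    float norm_row[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float neox_row[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    rope_norm_row(norm_row, 8, /*theta_base=*/3.0f, /*freq_base=*/10000.0f);
    rope_neox_row(neox_row, 8, /*theta_base=*/3.0f, /*freq_base=*/10000.0f);
    return 0;
}
```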
ggml-rpc.cpp:

@@ -491,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     if (remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
+            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
             remote_size);
         return buffer;
     } else {
@@ -692,7 +692,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
 GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
         /* .endpoint = */ endpoint,
-        /* .name = */ "RPC",
+        /* .name = */ "RPC[" + std::string(endpoint) + "]",
     };
 
     ggml_backend_t backend = new ggml_backend {
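Both RPC changes are the same cosmetic fix: the backend and its buffers now report a name that embeds the endpoint, so log output from several RPC backends can be told apart. A minimal sketch of the string construction, with a made-up endpoint value:

```cpp
#include <iostream>
#include <string>

int main() {
    const std::string endpoint = "192.168.1.10:50052";  // hypothetical endpoint
    const std::string name = "RPC[" + endpoint + "]";   // previously just "RPC"
    std::cout << name << '\n';                          // prints: RPC[192.168.1.10:50052]
    return 0;
}
```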
ggml-sycl.cpp:

@@ -8928,49 +8928,6 @@ static void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-, const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
-    // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
-    const float sin_theta = sycl::sin((float)theta);
-    const float cos_theta = sycl::cos((float)theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta =
-        ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
-    const float sin_block_theta = sycl::sin((float)block_theta);
-    const float cos_block_theta = sycl::cos((float)block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     }
 }
 
-static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
-                              const int32_t *pos, float freq_scale,
-                              int p_delta_rows, float freq_base, int n_ctx,
-                              dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
-    const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             rope_glm_f32(x, dst, ncols, pos, freq_scale,
-                                          p_delta_rows, freq_base, n_ctx,
-                                          item_ct1);
-                         });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, dpct::queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -14066,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
-    const int n_ctx = ((int32_t *) dst->op_params)[3];
-    const int
+    //const int n_ctx = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -14087,7 +14028,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     const bool is_neox = mode & 2;
-
+
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
 
     if (is_neox) {
         pos = (const int32_t *) src1_dd;
@@ -14100,13 +14043,10 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims,
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (
-        GGML_ASSERT(false);
-        rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,