llama_cpp 0.15.4 → 0.16.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/ext/llama_cpp/extconf.rb +1 -2
- data/ext/llama_cpp/llama_cpp.cpp +15 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +13 -1
- data/vendor/tmp/llama.cpp/Makefile +62 -35
- data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
- data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
- data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +8 -6
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -6
- data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +34 -24
- data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +2 -2
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +7 -67
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +456 -329
- data/vendor/tmp/llama.cpp/ggml.c +178 -330
- data/vendor/tmp/llama.cpp/ggml.h +9 -28
- data/vendor/tmp/llama.cpp/llama.cpp +242 -426
- data/vendor/tmp/llama.cpp/llama.h +17 -43
- metadata +121 -6
- data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
- data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
- data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
data/vendor/tmp/llama.cpp/ggml-metal.metal

@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    thread float * cos_theta, thread float * sin_theta
-) {
+    thread float * cos_theta, thread float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -1672,19 +1671,20 @@ static void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
-    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
 }
 
 static void rope_yarn_corr_dims(
-    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
     // start and end correction dims
-    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
 }
 
-typedef void (rope_t)(
+template<typename T>
+kernel void kernel_rope_norm(
     device const void * src0,
     device const int32_t * src1,
     device const float * src2,
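The hunk above renames n_orig_ctx to n_ctx_orig and begins replacing the rope_t typedef with a dedicated kernel_rope_norm kernel; the YaRN correction-factor math itself is unchanged. As a sanity check, here is a minimal host-side C++ sketch of the same formula, assuming llama.cpp's default beta_fast = 32 and beta_slow = 1, with M_PI standing in for Metal's M_PI_F:

#include <cmath>
#include <cstdio>

// Returns the dimension index whose rotation count reaches n_rot at the
// original training context length n_ctx_orig (the diff's corr_fac formula).
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float) M_PI)) / (2 * logf(base));
}

int main() {
    // Typical LLaMA-style values: 128 rotary dims, 4096 original context, base 10000.
    float lo = rope_yarn_corr_factor(128, 4096, 32.0f, 10000.0f); // beta_fast = 32
    float hi = rope_yarn_corr_factor(128, 4096,  1.0f, 10000.0f); // beta_slow = 1
    printf("correction dims: [%f, %f]\n", lo, hi);                // ends of the YaRN ramp
}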
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
     constant uint64_t & nb3,
     constant int & n_past,
     constant int & n_dims,
-    constant int & mode,
-    constant int & n_orig_ctx,
+    constant int & n_ctx_orig,
     constant float & freq_base,
     constant float & freq_scale,
     constant float & ext_factor,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
     constant float & beta_slow,
     uint tiitg[[thread_index_in_threadgroup]],
     uint3 tptg[[threads_per_threadgroup]],
-    uint3 tgpig[[threadgroup_position_in_grid]]);
+    uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
 
 template<typename T>
-kernel void kernel_rope(
+kernel void kernel_rope_neox(
     device const void * src0,
     device const int32_t * src1,
     device const float * src2,
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
     constant uint64_t & nb3,
     constant int & n_past,
     constant int & n_dims,
-    constant int & mode,
-    constant int & n_orig_ctx,
+    constant int & n_ctx_orig,
     constant float & freq_base,
     constant float & freq_scale,
     constant float & ext_factor,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
     const int64_t i2 = tgpig[1];
     const int64_t i1 = tgpig[0];
 
-    const bool is_neox = mode & 2;
-
     float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     device const int32_t * pos = src1;
 
-    const int64_t p = pos[i2];
-
-    const float theta_base = (float)p;
+    const float theta_base = (float) pos[i2];
     const float inv_ndims = -1.f/n_dims;
 
-    if (!is_neox) {
-        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
-            float cos_theta, sin_theta;
-            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-            const T x0 = src[0];
-            const T x1 = src[1];
+    float cos_theta;
+    float sin_theta;
 
-            dst_data[0] = x0*cos_theta - x1*sin_theta;
-            dst_data[1] = x0*sin_theta + x1*cos_theta;
-        }
-    } else {
-        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
-            if (ic < n_dims) {
-                const int64_t i0 = ic/2;
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
 
-                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
-
-                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-                float cos_theta, sin_theta;
-                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
 
-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-                const float x0 = src[0];
-                const float x1 = src[n_dims/2];
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
 
-                dst_data[0] = x0*cos_theta - x1*sin_theta;
-                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-            } else {
-                const int64_t i0 = ic;
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
 
-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                dst_data[0] = src[0];
-                dst_data[1] = src[1];
-            }
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
         }
     }
 }
 
-template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
-template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
 
 typedef void (im2col_t)(
     device const float * x,
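Taken together, these Metal hunks split the old mode-switching kernel_rope into kernel_rope_norm and kernel_rope_neox, which differ only in how elements are paired for rotation: the norm variant rotates adjacent pairs (i0, i0+1), while the neox variant pairs each element with its counterpart half the rotary width away (ic, ic + n_dims/2). A plain C++ sketch of the two pairings (reference-only helpers, not part of the diff):

#include <cmath>

// Rotate one (x0, x1) pair by angle theta; both RoPE variants reduce to this.
static void rotate_pair(float & x0, float & x1, float theta) {
    const float c = cosf(theta), s = sinf(theta);
    const float t0 = x0*c - x1*s;
    const float t1 = x0*s + x1*c;
    x0 = t0; x1 = t1;
}

// "norm" layout: adjacent pairs (0,1), (2,3), ...
static void rope_norm_ref(float * x, int n_dims, float theta_base, float freq_base) {
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const float theta = theta_base * powf(freq_base, -(float) i0 / n_dims);
        rotate_pair(x[i0], x[i0 + 1], theta);
    }
}

// "neox" layout: pairs split across halves (0, n_dims/2), (1, n_dims/2 + 1), ...
static void rope_neox_ref(float * x, int n_dims, float theta_base, float freq_base) {
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const float theta = theta_base * powf(freq_base, -(float) i0 / n_dims);
        rotate_pair(x[i0/2], x[i0/2 + n_dims/2], theta);
    }
}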
data/vendor/tmp/llama.cpp/ggml-rpc.cpp

@@ -491,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     if (remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
+            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
             remote_size);
         return buffer;
     } else {
@@ -692,7 +692,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
 GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
         /* .endpoint = */ endpoint,
-        /* .name     = */ "RPC",
+        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
     };
 
     ggml_backend_t backend = new ggml_backend {
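These two ggml-rpc.cpp hunks embed the endpoint in the backend and buffer names, so multiple RPC backends are distinguishable in logs and debug output. A hypothetical usage sketch (the endpoint addresses are placeholders):

#include "ggml-backend.h"
#include "ggml-rpc.h"
#include <cstdio>

int main() {
    // Each backend now reports "RPC[<endpoint>]" instead of a bare "RPC".
    ggml_backend_t b0 = ggml_backend_rpc_init("192.168.0.10:50052");
    ggml_backend_t b1 = ggml_backend_rpc_init("192.168.0.11:50052");
    printf("%s\n", ggml_backend_name(b0)); // e.g. RPC[192.168.0.10:50052]
    printf("%s\n", ggml_backend_name(b1)); // e.g. RPC[192.168.0.11:50052]
    ggml_backend_free(b0);
    ggml_backend_free(b1);
}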
data/vendor/tmp/llama.cpp/ggml-sycl.cpp

@@ -8928,49 +8928,6 @@ static void rope_neox(
         dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-, const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
-    // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
-    const float sin_theta = sycl::sin((float)theta);
-    const float cos_theta = sycl::cos((float)theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta =
-        ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
-    const float sin_block_theta = sycl::sin((float)block_theta);
-    const float cos_block_theta = sycl::cos((float)block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     }
 }
 
-static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
-                              const int32_t *pos, float freq_scale,
-                              int p_delta_rows, float freq_base, int n_ctx,
-                              dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
-    const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             rope_glm_f32(x, dst, ncols, pos, freq_scale,
-                                          p_delta_rows, freq_base, n_ctx,
-                                          item_ct1);
-                         });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, dpct::queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -14066,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
-    const int n_ctx = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -14087,7 +14028,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     const bool is_neox = mode & 2;
-
+
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
 
     if (is_neox) {
         pos = (const int32_t *) src1_dd;
@@ -14100,13 +14043,10 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,