llama_cpp 0.15.4 → 0.16.0

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (147)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +15 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +13 -1
  7. data/vendor/tmp/llama.cpp/Makefile +62 -35
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +8 -6
  131. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -6
  132. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  133. data/vendor/tmp/llama.cpp/ggml-metal.m +34 -24
  134. data/vendor/tmp/llama.cpp/ggml-metal.metal +83 -59
  135. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +2 -2
  136. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +7 -67
  137. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  138. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +456 -329
  139. data/vendor/tmp/llama.cpp/ggml.c +178 -330
  140. data/vendor/tmp/llama.cpp/ggml.h +9 -28
  141. data/vendor/tmp/llama.cpp/llama.cpp +242 -426
  142. data/vendor/tmp/llama.cpp/llama.h +17 -43
  143. metadata +121 -6
  144. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  145. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  146. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  147. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    thread float * cos_theta, thread float * sin_theta
-) {
+    thread float * cos_theta, thread float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
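Only the brace placement of `rope_yarn` changes in the hunk above, but the function anchors everything that follows, so it is worth stating what it computes. Reconstructed from the surrounding definitions (a summary, not text present in the diff), with the mix weight taken from `rope_yarn_ramp` over the correction dims:

$$\gamma = \mathrm{ramp}(c_0, c_1, i_0)\cdot\mathrm{ext\_factor}, \qquad \theta_{\mathrm{interp}} = \mathrm{freq\_scale}\cdot\theta_{\mathrm{extrap}}, \qquad \theta = (1-\gamma)\,\theta_{\mathrm{interp}} + \gamma\,\theta_{\mathrm{extrap}},$$

and the values written through the output pointers are $\cos\theta\cdot\mathrm{mscale}$ and $\sin\theta\cdot\mathrm{mscale}$.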
@@ -1672,19 +1671,20 @@ static void rope_yarn(
 
 // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
 // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
-    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
 }
 
 static void rope_yarn_corr_dims(
-    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
 ) {
     // start and end correction dims
-    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
 }
 
-typedef void (rope_t)(
+template<typename T>
+kernel void kernel_rope_norm(
     device const void * src0,
     device const int32_t * src1,
     device const float * src2,
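The two comment lines above compress the YaRN correction-factor derivation; reconstructed here for clarity (following the YaRN paper's reasoning, not text present in the diff): dimension $i$ of a RoPE head rotates with wavelength $2\pi\cdot\mathrm{base}^{2i/n_{\mathrm{dims}}}$, so over $n_{\mathrm{ctx\_orig}}$ training positions it completes

$$n_{\mathrm{rot}} = \frac{n_{\mathrm{ctx\_orig}}}{2\pi\cdot\mathrm{base}^{2i/n_{\mathrm{dims}}}}$$

full revolutions. Solving for $i$ at a given rotation count gives

$$i = \frac{n_{\mathrm{dims}}}{2\ln\mathrm{base}}\,\ln\!\left(\frac{n_{\mathrm{ctx\_orig}}}{2\pi\,n_{\mathrm{rot}}}\right),$$

which is exactly what `rope_yarn_corr_factor` returns. `rope_yarn_corr_dims` then evaluates it at `beta_fast` and `beta_slow` and clamps with floor/ceil to get the start and end correction dimensions. The `n_orig_ctx` → `n_ctx_orig` rename is purely cosmetic.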
@@ -1707,8 +1707,7 @@ typedef void (rope_t)(
     constant uint64_t & nb3,
     constant int & n_past,
     constant int & n_dims,
-    constant int & mode,
-    constant int & n_orig_ctx,
+    constant int & n_ctx_orig,
     constant float & freq_base,
     constant float & freq_scale,
     constant float & ext_factor,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
     constant float & beta_slow,
     uint tiitg[[thread_index_in_threadgroup]],
     uint3 tptg[[threads_per_threadgroup]],
-    uint3 tgpig[[threadgroup_position_in_grid]]);
+    uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
 
 template<typename T>
-kernel void kernel_rope(
+kernel void kernel_rope_neox(
     device const void * src0,
     device const int32_t * src1,
     device const float * src2,
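The hunk above carries the complete body of the new `kernel_rope_norm`. A minimal host-side C++ sketch of the same computation for one contiguous f32 row may help; it is an illustration under simplifying assumptions (no YaRN blend, no src2 frequency factors, no strided device addressing), not the kernel itself:

#include <cmath>
#include <cstddef>
#include <vector>

// NORM-mode RoPE: rotate *adjacent* pairs (x[i0], x[i0+1]);
// elements at i0 >= n_dims pass through unchanged, as in the
// kernel's else-branch above.
void rope_norm_row(std::vector<float> & x, int n_dims, int pos, float freq_base) {
    for (std::size_t i0 = 0; i0 + 1 < x.size(); i0 += 2) {
        if ((int) i0 >= n_dims) {
            break; // everything past n_dims is a plain copy in-place
        }
        const float theta = pos * std::pow(freq_base, -(float) i0 / n_dims);
        const float c = std::cos(theta);
        const float s = std::sin(theta);
        const float x0 = x[i0];
        const float x1 = x[i0 + 1];
        x[i0]     = x0 * c - x1 * s;
        x[i0 + 1] = x0 * s + x1 * c;
    }
}

The NEOX variant that follows differs only in the pairing: it rotates (x[ic], x[ic + n_dims/2]) instead of adjacent elements.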
@@ -1743,8 +1784,7 @@ kernel void kernel_rope(
     constant uint64_t & nb3,
     constant int & n_past,
     constant int & n_dims,
-    constant int & mode,
-    constant int & n_orig_ctx,
+    constant int & n_ctx_orig,
     constant float & freq_base,
     constant float & freq_scale,
     constant float & ext_factor,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
     const int64_t i2 = tgpig[1];
     const int64_t i1 = tgpig[0];
 
-    const bool is_neox = mode & 2;
-
     float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     device const int32_t * pos = src1;
 
-    const int64_t p = pos[i2];
-
-    const float theta_base = (float)p;
+    const float theta_base = (float) pos[i2];
     const float inv_ndims = -1.f/n_dims;
 
-    if (!is_neox) {
-        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
-
-            float cos_theta, sin_theta;
-            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-            const T x0 = src[0];
-            const T x1 = src[1];
+    float cos_theta;
+    float sin_theta;
 
-            dst_data[0] = x0*cos_theta - x1*sin_theta;
-            dst_data[1] = x0*sin_theta + x1*cos_theta;
-        }
-    } else {
-        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
-            if (ic < n_dims) {
-                const int64_t i0 = ic/2;
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
 
-                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
-
-                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-                float cos_theta, sin_theta;
-                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
 
-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-                const float x0 = src[0];
-                const float x1 = src[n_dims/2];
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
 
-                dst_data[0] = x0*cos_theta - x1*sin_theta;
-                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-            } else {
-                const int64_t i0 = ic;
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
 
-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                dst_data[0] = src[0];
-                dst_data[1] = src[1];
-            }
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
         }
     }
 }
 
-template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
-template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
 
 typedef void (im2col_t)(
     device const float * x,
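One detail at the end of this hunk deserves a note: the hand-maintained `rope_t` function-pointer typedef is gone, and the exported kernel types are now derived with `decltype` from a concrete instantiation, so they can never drift out of sync with the kernels' long parameter lists. A minimal C++ analogue of the pattern, with illustrative names that are not from the diff:

#include <cstddef>

// Stand-in for the Metal kernel template.
template <typename T>
void rope_kernel(const T * src, T * dst, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        dst[i] = src[i];
    }
}

// Derive the typedef from the template itself instead of restating
// the parameter list by hand, as the typedefs above do.
typedef decltype(rope_kernel<float>) rope_kernel_t;

// Explicit instantiations then all share the deduced type, which is
// what the [[host_name(...)]] lines do for each f32/f16 kernel.
template void rope_kernel<float>(const float *, float *, std::size_t);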
@@ -491,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     if (remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
+            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
             remote_size);
         return buffer;
     } else {
@@ -692,7 +692,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
 GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
         /* .endpoint = */ endpoint,
-        /* .name = */ "RPC",
+        /* .name = */ "RPC[" + std::string(endpoint) + "]",
     };
 
     ggml_backend_t backend = new ggml_backend {
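Both RPC hunks make the same change: the backend and its buffers now embed the server endpoint in their debug name instead of the fixed label "RPC", so logs from multi-server setups are attributable. A hedged usage sketch (the endpoint addresses are made up; `ggml_backend_name` and `ggml_backend_free` are the standard accessors from ggml-backend.h):

#include <cstdio>
#include "ggml-backend.h"
#include "ggml-rpc.h"

int main() {
    // With two RPC backends alive at once, the old fixed "RPC" name made
    // their log lines indistinguishable; now each carries its endpoint.
    ggml_backend_t b0 = ggml_backend_rpc_init("192.168.1.10:50052");
    ggml_backend_t b1 = ggml_backend_rpc_init("192.168.1.11:50052");
    std::printf("%s\n", ggml_backend_name(b0)); // e.g. RPC[192.168.1.10:50052]
    std::printf("%s\n", ggml_backend_name(b1)); // e.g. RPC[192.168.1.11:50052]
    ggml_backend_free(b0);
    ggml_backend_free(b1);
    return 0;
}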
@@ -8928,49 +8928,6 @@ static void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-    , const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
-    // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
-    const float sin_theta = sycl::sin((float)theta);
-    const float cos_theta = sycl::cos((float)theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta =
-        ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
-    const float sin_block_theta = sycl::sin((float)block_theta);
-    const float cos_block_theta = sycl::cos((float)block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -12520,22 +12477,6 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     }
 }
 
-static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
-                              const int32_t *pos, float freq_scale,
-                              int p_delta_rows, float freq_base, int n_ctx,
-                              dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
-    const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             rope_glm_f32(x, dst, ncols, pos, freq_scale,
-                                          p_delta_rows, freq_base, n_ctx,
-                                          item_ct1);
-                         });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, dpct::queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -14066,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
-    const int n_ctx = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
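For reference, the indices read at the top of `ggml_sycl_op_rope` follow ggml's packed RoPE parameter layout: `op_params` is an int32 array whose float fields are stored bitwise. A sketch of unpacking it in plain C++ (the float offsets 5-10 are quoted from upstream ggml, not from this diff, so treat them as an assumption):

#include <cstdint>
#include <cstring>

struct rope_params {
    int   n_dims, mode, n_ctx_orig;
    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
};

// [0] = n_past (unused), [1] = n_dims, [2] = mode, [3] = n_ctx (now unused),
// [4] = n_ctx_orig, then six floats memcpy'd in from index 5 onward.
rope_params unpack_rope_params(const int32_t * op_params) {
    rope_params p;
    p.n_dims     = op_params[1];
    p.mode       = op_params[2];
    p.n_ctx_orig = op_params[4];
    std::memcpy(&p.freq_base,   op_params +  5, sizeof(float));
    std::memcpy(&p.freq_scale,  op_params +  6, sizeof(float));
    std::memcpy(&p.ext_factor,  op_params +  7, sizeof(float));
    std::memcpy(&p.attn_factor, op_params +  8, sizeof(float));
    std::memcpy(&p.beta_fast,   op_params +  9, sizeof(float));
    std::memcpy(&p.beta_slow,   op_params + 10, sizeof(float));
    return p;
}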
@@ -14087,7 +14028,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     const bool is_neox = mode & 2;
-    const bool is_glm = mode & 4;
+
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
 
     if (is_neox) {
         pos = (const int32_t *) src1_dd;
@@ -14100,13 +14043,10 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,