llama_cpp 0.15.3 → 0.16.0

Files changed (149)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +27 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +15 -1
  7. data/vendor/tmp/llama.cpp/Makefile +66 -36
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +35 -16
  131. data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
  132. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -7
  133. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  134. data/vendor/tmp/llama.cpp/ggml-metal.m +99 -35
  135. data/vendor/tmp/llama.cpp/ggml-metal.metal +146 -80
  136. data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
  137. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
  138. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +345 -227
  139. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  140. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +458 -329
  141. data/vendor/tmp/llama.cpp/ggml.c +301 -409
  142. data/vendor/tmp/llama.cpp/ggml.h +19 -23
  143. data/vendor/tmp/llama.cpp/llama.cpp +855 -651
  144. data/vendor/tmp/llama.cpp/llama.h +28 -48
  145. metadata +121 -6
  146. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  147. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  148. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  149. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
data/vendor/tmp/llama.cpp/ggml.c
@@ -5,6 +5,7 @@
  #include "ggml-quants.h"
  #include "ggml.h"

+
  #if defined(_MSC_VER) || defined(__MINGW32__)
  #include <malloc.h> // using malloc.h with MSC/MINGW
  #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -28,6 +29,10 @@
  #include <syscall.h>
  #endif

+ #ifdef GGML_USE_OPENMP
+ #include <omp.h>
+ #endif
+
  #ifdef GGML_USE_METAL
  #include <unistd.h>
  #endif
@@ -60,6 +65,9 @@

  typedef volatile LONG atomic_int;
  typedef atomic_int atomic_bool;
+ typedef atomic_int atomic_flag;
+
+ #define ATOMIC_FLAG_INIT 0

  static void atomic_store(atomic_int * ptr, LONG val) {
  InterlockedExchange(ptr, val);
@@ -73,6 +81,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
  static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
  return atomic_fetch_add(ptr, -(dec));
  }
+ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+ return InterlockedExchange(ptr, 1);
+ }
+ static void atomic_flag_clear(atomic_flag * ptr) {
+ InterlockedExchange(ptr, 0);
+ }

  typedef HANDLE pthread_t;

@@ -283,17 +297,12 @@ inline static void * ggml_calloc(size_t num, size_t size) {

  #if defined(GGML_USE_ACCELERATE)
  #include <Accelerate/Accelerate.h>
- #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
- #include "ggml-opencl.h"
- #endif
  #elif defined(GGML_USE_OPENBLAS)
  #if defined(GGML_BLAS_USE_MKL)
  #include <mkl.h>
  #else
  #include <cblas.h>
  #endif
- #elif defined(GGML_USE_CLBLAST)
- #include "ggml-opencl.h"
  #endif

  // floating point type used to accumulate sums
@@ -1567,11 +1576,11 @@ do { \

  // F16 arithmetic is not supported by AVX, so we use F32 instead

- #define GGML_F32Cx8 __m256
+ #define GGML_F32Cx8 __m256
  #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
  #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))

- static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
+ static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
  float tmp[8];

  for (int i = 0; i < 8; i++) {
@@ -1580,13 +1589,14 @@ static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {

  return (__m256)__lasx_xvld(tmp, 0);
  }
- static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
  float arr[8];

  __lasx_xvst(y, arr, 0);

- for (int i = 0; i < 8; i++)
+ for (int i = 0; i < 8; i++) {
  x[i] = GGML_FP32_TO_FP16(arr[i]);
+ }
  }
  #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
  #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
@@ -1662,7 +1672,7 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
  #define GGML_F16_STEP 32
  #define GGML_F16_EPR 4

- static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
+ static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
  float tmp[4];

  tmp[0] = GGML_FP16_TO_FP32(x[0]);
@@ -1673,7 +1683,7 @@ static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
  return __lsx_vld(tmp, 0);
  }

- static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
+ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
  float arr[4];

  __lsx_vst(y, arr, 0);
@@ -1746,7 +1756,7 @@ struct ggml_compute_state_shared {
  int64_t perf_node_start_cycles;
  int64_t perf_node_start_time_us;

- const int n_threads;
+ int n_threads;

  // synchronization primitives
  atomic_int n_active; // num active threads
@@ -2257,6 +2267,11 @@ inline static float ggml_silu_f32(float x) {
  return x/(1.0f + expf(-x));
  }

+ #if __FINITE_MATH_ONLY__
+ #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
+ #error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
+ #endif
+
  #if defined(__ARM_NEON) && defined(__aarch64__)

  // adapted from arm limited optimized routine
@@ -2306,32 +2321,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
  const __m512 r = _mm512_set1_ps(0x1.8p23f);
  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
  const __m512 n = _mm512_sub_ps(z, r);
- const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
- _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
- const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
- const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
- const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
- const __m512 u = _mm512_mul_ps(b, b);
- const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
- _mm512_set1_ps(0x1.573e2ep-5f)), u,
- _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
- _mm512_set1_ps(0x1.fffdb6p-2f))),
- u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
- if (_mm512_kortestz(c, c))
- return _mm512_fmadd_ps(j, k, k);
- const __m512i g = _mm512_and_si512(
- _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
- _mm512_set1_epi32(0x82000000u));
- const __m512 s1 =
- _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
- const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
+ const __m512 b =
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
  const __mmask16 d =
  _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
- return _mm512_mask_blend_ps(
- d, _mm512_mask_blend_ps(
- c, _mm512_fmadd_ps(k, j, k),
- _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
- _mm512_mul_ps(s1, s1));
+ const __m512 u = _mm512_mul_ps(b, b);
+ const __m512 j = _mm512_fmadd_ps(
+ _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+ _mm512_set1_ps(0x1.573e2ep-5f)),
+ u,
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
+ u,
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+ const __m512 res = _mm512_scalef_ps(j, n);
+ if (_mm512_kortestz(d, d))
+ return res;
+ const __m512 zero = _mm512_setzero_ps();
+ const __m512 alt = _mm512_mask_blend_ps(
+ _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+ return _mm512_mask_blend_ps(d, res, alt);
  }

  // computes silu x/(1+exp(-x)) in single precision vector
@@ -2883,24 +2893,20 @@ struct ggml_state {

  // global state
  static struct ggml_state g_state;
- static atomic_int g_state_barrier = 0;
+ static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;

  // barrier via spin lock
  inline static void ggml_critical_section_start(void) {
- int processing = atomic_fetch_add(&g_state_barrier, 1);
-
- while (processing > 0) {
- // wait for other threads to finish
- atomic_fetch_sub(&g_state_barrier, 1);
- sched_yield(); // TODO: reconsider this
- processing = atomic_fetch_add(&g_state_barrier, 1);
+ while (atomic_flag_test_and_set(&g_state_critical)) {
+ // spin
+ sched_yield();
  }
  }

  // TODO: make this somehow automatically executed
  // some sort of "sentry" mechanism
  inline static void ggml_critical_section_end(void) {
- atomic_fetch_sub(&g_state_barrier, 1);
+ atomic_flag_clear(&g_state_critical);
  }

  #if defined(__gnu_linux__)
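In the hunk above the global-state guard moves from an atomic add/sub counter to a test-and-set spin lock, paired with the `atomic_flag` shim for MSVC/MinGW added earlier in this file. A minimal sketch of the same pattern using C11 <stdatomic.h> (the lock_enter/lock_leave names are illustrative, not part of the patch):

    #include <stdatomic.h>
    #include <sched.h>

    static atomic_flag g_lock = ATOMIC_FLAG_INIT;

    static void lock_enter(void) {
        // atomic_flag_test_and_set returns the previous value,
        // so the loop spins until it observes the flag clear
        while (atomic_flag_test_and_set(&g_lock)) {
            sched_yield(); // yield instead of busy-burning the core
        }
    }

    static void lock_leave(void) {
        atomic_flag_clear(&g_lock); // release: the next contender's test-and-set succeeds
    }

Unlike the old counter-based barrier, a single flag cannot momentarily over-count waiters, and releasing the lock is a single clear.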
@@ -3216,7 +3222,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
  }

- static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+ GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+ return ggml_is_contiguous(tensor);
+ }
+
+ GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

  return
@@ -3225,6 +3235,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
  }

+ GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ tensor->nb[0] == ggml_type_size(tensor->type) &&
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+ }
+
  GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

@@ -3357,10 +3375,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
  GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
  }

- #if defined(GGML_USE_CLBLAST)
- ggml_cl_init();
- #endif
-
  ggml_setup_op_has_task_pass();

  is_first_call = false;
@@ -4882,10 +4896,21 @@ struct ggml_tensor * ggml_repeat_back(
  // ggml_concat

  struct ggml_tensor * ggml_concat(
- struct ggml_context* ctx,
- struct ggml_tensor* a,
- struct ggml_tensor* b) {
- GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int dim) {
+ GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+ int64_t ne[GGML_MAX_DIMS];
+ for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+ if (d == dim) {
+ ne[d] = a->ne[d] + b->ne[d];
+ continue;
+ }
+ GGML_ASSERT(a->ne[d] == b->ne[d]);
+ ne[d] = a->ne[d];
+ }

  bool is_node = false;

@@ -4893,7 +4918,9 @@ struct ggml_tensor * ggml_concat(
  is_node = true;
  }

- struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+ ggml_set_op_params_i32(result, 0, dim);

  result->op = GGML_OP_CONCAT;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
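ggml_concat now takes an explicit dim argument: the two inputs must match on every other dimension, the result's size along dim is the sum of the inputs, and the chosen axis is stored in the op params for the backends to read back. A usage sketch under the new signature (the tensor shapes here are illustrative):

    // concatenate along dimension 2 -- the only behaviour the old 3-argument API offered
    struct ggml_tensor * a  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 32, 4);
    struct ggml_tensor * b  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 32, 8);
    struct ggml_tensor * ab = ggml_concat(ctx, a, b, 2);   // ne = [64, 32, 12, 1]

    // any axis works now, e.g. dimension 0
    struct ggml_tensor * c  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 16, 32, 4);
    struct ggml_tensor * ac = ggml_concat(ctx, a, c, 0);   // ne = [80, 32, 4, 1]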
@@ -5013,6 +5040,7 @@ struct ggml_tensor * ggml_leaky_relu(
  }

  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
  ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));

  result->op = GGML_OP_LEAKY_RELU;
@@ -6222,16 +6250,13 @@ static struct ggml_tensor * ggml_rope_impl(
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
  float attn_factor,
  float beta_fast,
  float beta_slow,
- float xpos_base,
- bool xpos_down,
  bool inplace) {
  GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");

@@ -6252,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl(

  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

- int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+ int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
  memcpy(params + 5, &freq_base, sizeof(float));
  memcpy(params + 6, &freq_scale, sizeof(float));
  memcpy(params + 7, &ext_factor, sizeof(float));
  memcpy(params + 8, &attn_factor, sizeof(float));
  memcpy(params + 9, &beta_fast, sizeof(float));
  memcpy(params + 10, &beta_slow, sizeof(float));
- memcpy(params + 11, &xpos_base, sizeof(float));
- memcpy(params + 12, &xpos_down, sizeof(bool));
  ggml_set_op_params(result, params, sizeof(params));

  result->op = GGML_OP_ROPE;
@@ -6277,10 +6300,9 @@ struct ggml_tensor * ggml_rope(
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  int n_dims,
- int mode,
- int n_ctx) {
+ int mode) {
  return ggml_rope_impl(
- ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+ ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
  );
  }

@@ -6289,10 +6311,9 @@ struct ggml_tensor * ggml_rope_inplace(
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  int n_dims,
- int mode,
- int n_ctx) {
+ int mode) {
  return ggml_rope_impl(
- ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+ ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
  );
  }

@@ -6303,8 +6324,7 @@ struct ggml_tensor * ggml_rope_ext(
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -6312,8 +6332,8 @@ struct ggml_tensor * ggml_rope_ext(
  float beta_fast,
  float beta_slow) {
  return ggml_rope_impl(
- ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, false
  );
  }

@@ -6324,8 +6344,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -6333,8 +6352,8 @@ struct ggml_tensor * ggml_rope_ext_inplace(
  float beta_fast,
  float beta_slow) {
  return ggml_rope_impl(
- ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, true
  );
  }

@@ -6344,8 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
  struct ggml_tensor * b,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -6353,8 +6371,8 @@ struct ggml_tensor * ggml_rope_custom(
  float beta_fast,
  float beta_slow) {
  return ggml_rope_impl(
- ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+ ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, false
  );
  }

@@ -6364,8 +6382,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
  struct ggml_tensor * b,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
@@ -6373,8 +6390,8 @@ struct ggml_tensor * ggml_rope_custom_inplace(
  float beta_fast,
  float beta_slow) {
  return ggml_rope_impl(
- ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+ ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, true
  );
  }

@@ -6387,16 +6404,13 @@ struct ggml_tensor * ggml_rope_back(
  struct ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
- int n_orig_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
  float ext_factor,
  float attn_factor,
  float beta_fast,
- float beta_slow,
- float xpos_base,
- bool xpos_down) {
+ float beta_slow) {
  GGML_ASSERT(ggml_is_vector(b));
  GGML_ASSERT(b->type == GGML_TYPE_I32);
  GGML_ASSERT(a->ne[2] == b->ne[0]);
@@ -6412,15 +6426,13 @@ struct ggml_tensor * ggml_rope_back(

  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

- int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+ int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
  memcpy(params + 5, &freq_base, sizeof(float));
  memcpy(params + 6, &freq_scale, sizeof(float));
  memcpy(params + 7, &ext_factor, sizeof(float));
  memcpy(params + 8, &attn_factor, sizeof(float));
  memcpy(params + 9, &beta_fast, sizeof(float));
  memcpy(params + 10, &beta_slow, sizeof(float));
- memcpy(params + 11, &xpos_base, sizeof(float));
- memcpy(params + 12, &xpos_down, sizeof(bool));
  ggml_set_op_params(result, params, sizeof(params));

  result->op = GGML_OP_ROPE_BACK;
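The RoPE builders drop the unused n_ctx parameter and the xPos arguments (xpos_base, xpos_down), and n_orig_ctx is renamed n_ctx_orig; op slot 3 is kept as a zero placeholder so the param layout stays 11 int32 values. A sketch of a call site updated for the new signatures (the variable names are placeholders for the caller's own values):

    // 0.15.x: ggml_rope_ext(ctx, x, pos, ff, n_rot, mode, n_ctx, n_orig_ctx, ...)
    // 0.16.0: the n_ctx slot is gone and n_orig_ctx is now n_ctx_orig
    struct ggml_tensor * cur = ggml_rope_ext(
        ctx, x, pos, freq_factors,            // freq_factors may be NULL
        n_rot, mode, n_ctx_orig,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

    // the simple helpers lose their trailing n_ctx argument as well
    struct ggml_tensor * cur2 = ggml_rope(ctx, x, pos, n_rot, mode);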
@@ -9006,17 +9018,6 @@ static void ggml_compute_forward_add_f32(
9006
9018
  const int ith = params->ith;
9007
9019
  const int nth = params->nth;
9008
9020
 
9009
- #ifdef GGML_USE_CLBLAST
9010
- if (src1->backend == GGML_BACKEND_TYPE_GPU) {
9011
- // TODO: OpenCL kernel support full broadcast
9012
- GGML_ASSERT(ggml_can_repeat_rows(src1, src0));
9013
- if (ith == 0) {
9014
- ggml_cl_add(src0, src1, dst);
9015
- }
9016
- return;
9017
- }
9018
- #endif
9019
-
9020
9021
  const int nr = ggml_nrows(src0);
9021
9022
 
9022
9023
  GGML_TENSOR_BINARY_OP_LOCALS
@@ -10124,17 +10125,6 @@ static void ggml_compute_forward_mul_f32(
10124
10125
  const int ith = params->ith;
10125
10126
  const int nth = params->nth;
10126
10127
 
10127
- #if defined(GGML_USE_CLBLAST)
10128
- if (src1->backend == GGML_BACKEND_TYPE_GPU) {
10129
- // TODO: OpenCL kernel support full broadcast
10130
- GGML_ASSERT(ggml_can_repeat_rows(src1, src0));
10131
- if (ith == 0) {
10132
- ggml_cl_mul(src0, src1, dst);
10133
- }
10134
- return;
10135
- }
10136
- #endif
10137
-
10138
10128
  const int64_t nr = ggml_nrows(src0);
10139
10129
 
10140
10130
  GGML_TENSOR_BINARY_OP_LOCALS
@@ -10967,26 +10957,29 @@ static void ggml_compute_forward_concat_f32(
10967
10957
  GGML_ASSERT(nb00 == sizeof(float));
10968
10958
  GGML_ASSERT(nb10 == sizeof(float));
10969
10959
 
10960
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
10961
+
10962
+ GGML_ASSERT(dim >= 0 && dim < 4);
10963
+
10964
+ int64_t o[4] = {0, 0, 0, 0};
10965
+ o[dim] = src0->ne[dim];
10966
+
10967
+ const float * x;
10968
+
10969
+ // TODO: smarter multi-theading
10970
10970
  for (int i3 = 0; i3 < ne3; i3++) {
10971
10971
  for (int i2 = ith; i2 < ne2; i2 += nth) {
10972
- if (i2 < ne02) { // src0
10973
- for (int i1 = 0; i1 < ne1; i1++) {
10974
- for (int i0 = 0; i0 < ne0; i0++) {
10975
- const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
10976
-
10977
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10978
- *y = *x;
10979
- }
10980
- }
10981
- } // src1
10982
- else {
10983
- for (int i1 = 0; i1 < ne1; i1++) {
10984
- for (int i0 = 0; i0 < ne0; i0++) {
10985
- const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
10986
-
10987
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10988
- *y = *x;
10972
+ for (int i1 = 0; i1 < ne1; i1++) {
10973
+ for (int i0 = 0; i0 < ne0; i0++) {
10974
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
10975
+ x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
10976
+ } else {
10977
+ x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
10989
10978
  }
10979
+
10980
+ float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
10981
+
10982
+ *y = *x;
10990
10983
  }
10991
10984
  }
10992
10985
  }
@@ -10994,8 +10987,8 @@ static void ggml_compute_forward_concat_f32(
10994
10987
  }
10995
10988
 
10996
10989
  static void ggml_compute_forward_concat(
10997
- const struct ggml_compute_params* params,
10998
- struct ggml_tensor* dst) {
10990
+ const struct ggml_compute_params * params,
10991
+ struct ggml_tensor * dst) {
10999
10992
 
11000
10993
  const struct ggml_tensor * src0 = dst->src[0];
11001
10994
 
@@ -11388,8 +11381,8 @@ static void ggml_compute_forward_gelu_f32(
11388
11381
 
11389
11382
  const struct ggml_tensor * src0 = dst->src[0];
11390
11383
 
11391
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11392
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11384
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11385
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11393
11386
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11394
11387
 
11395
11388
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11451,8 +11444,8 @@ static void ggml_compute_forward_gelu_quick_f32(
11451
11444
 
11452
11445
  const struct ggml_tensor * src0 = dst->src[0];
11453
11446
 
11454
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11455
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11447
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11448
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11456
11449
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11457
11450
 
11458
11451
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11514,8 +11507,8 @@ static void ggml_compute_forward_silu_f32(
11514
11507
 
11515
11508
  const struct ggml_tensor * src0 = dst->src[0];
11516
11509
 
11517
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11518
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11510
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11511
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11519
11512
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11520
11513
 
11521
11514
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11626,9 +11619,9 @@ static void ggml_compute_forward_silu_back_f32(
11626
11619
  const struct ggml_tensor * src0 = dst->src[0];
11627
11620
  const struct ggml_tensor * grad = dst->src[1];
11628
11621
 
11629
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
11630
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11631
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11622
+ GGML_ASSERT(ggml_is_contiguous_1(grad));
11623
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11624
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11632
11625
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11633
11626
  GGML_ASSERT(ggml_are_same_shape(src0, grad));
11634
11627
 
@@ -12367,15 +12360,6 @@ static void ggml_compute_forward_mul_mat(
12367
12360
  // nb01 >= nb00 - src0 is not transposed
12368
12361
  // compute by src0 rows
12369
12362
 
12370
- #if defined(GGML_USE_CLBLAST)
12371
- if (ggml_cl_can_mul_mat(src0, src1, dst)) {
12372
- if (params->ith == 0 && params->type == GGML_TASK_TYPE_COMPUTE) {
12373
- ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
12374
- }
12375
- return;
12376
- }
12377
- #endif
12378
-
12379
12363
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
12380
12364
  if (ggml_compute_forward_mul_mat_use_blas(dst)) {
12381
12365
  const int64_t ne_plane = ne01*ne00;
@@ -12823,8 +12807,6 @@ static void ggml_compute_forward_out_prod_f32(
12823
12807
  // nb01 >= nb00 - src0 is not transposed
12824
12808
  // compute by src0 rows
12825
12809
 
12826
- // TODO: #if defined(GGML_USE_CLBLAST)
12827
-
12828
12810
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
12829
12811
  bool use_blas = ggml_is_matrix(src0) &&
12830
12812
  ggml_is_matrix(src1) &&
@@ -13022,7 +13004,7 @@ static void ggml_compute_forward_out_prod_q_f32(
13022
13004
  // nb01 >= nb00 - src0 is not transposed
13023
13005
  // compute by src0 rows
13024
13006
 
13025
- // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
13007
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
13026
13008
 
13027
13009
  if (params->type == GGML_TASK_TYPE_INIT) {
13028
13010
  if (ith != 0) {
@@ -14219,8 +14201,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
14219
14201
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
14220
14202
  static void rope_yarn(
14221
14203
  float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
14222
- float * cos_theta, float * sin_theta
14223
- ) {
14204
+ float * cos_theta, float * sin_theta) {
14224
14205
  // Get n-d rotational scaling corrected for extrapolation
14225
14206
  float theta_interp = freq_scale * theta_extrap;
14226
14207
  float theta = theta_interp;
@@ -14237,18 +14218,19 @@ static void rope_yarn(
14237
14218
 
14238
14219
  // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
14239
14220
  // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
14240
- static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
14241
- return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
14221
+ static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
14222
+ return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
14242
14223
  }
14243
14224
 
14244
14225
  static void ggml_rope_cache_init(
14245
- float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
14246
- float * cache, float sin_sign, float theta_scale
14247
- ) {
14226
+ float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
14227
+ float * cache, float sin_sign, float theta_scale) {
14228
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
14248
14229
  float theta = theta_base;
14249
14230
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
14231
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
14250
14232
  rope_yarn(
14251
- theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
14233
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
14252
14234
  );
14253
14235
  cache[i0 + 1] *= sin_sign;
14254
14236
 
@@ -14257,11 +14239,11 @@ static void ggml_rope_cache_init(
14257
14239
  }
14258
14240
 
14259
14241
  GGML_CALL void ggml_rope_yarn_corr_dims(
14260
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
14242
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
14261
14243
  ) {
14262
14244
  // start and end correction dims
14263
- float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base));
14264
- float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base));
14245
+ float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
14246
+ float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
14265
14247
  dims[0] = MAX(0, start);
14266
14248
  dims[1] = MIN(n_dims - 1, end);
14267
14249
  }
@@ -14281,15 +14263,11 @@ static void ggml_compute_forward_rope_f32(
14281
14263
 
14282
14264
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
14283
14265
 
14284
- // these two only relevant for xPos RoPE:
14285
- float xpos_base;
14286
- bool xpos_down;
14287
-
14288
14266
  //const int n_past = ((int32_t *) dst->op_params)[0];
14289
14267
  const int n_dims = ((int32_t *) dst->op_params)[1];
14290
14268
  const int mode = ((int32_t *) dst->op_params)[2];
14291
- const int n_ctx = ((int32_t *) dst->op_params)[3];
14292
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
14269
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
14270
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
14293
14271
 
14294
14272
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
14295
14273
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -14297,8 +14275,6 @@ static void ggml_compute_forward_rope_f32(
14297
14275
  memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
14298
14276
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
14299
14277
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
14300
- memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
14301
- memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
14302
14278
 
14303
14279
  GGML_TENSOR_UNARY_OP_LOCALS
14304
14280
 
@@ -14326,22 +14302,17 @@ static void ggml_compute_forward_rope_f32(
14326
14302
  int ir = 0;
14327
14303
 
14328
14304
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
14329
- const float inv_ndims = -1.f/n_dims;
14305
+
14330
14306
  float corr_dims[2];
14331
- ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
14307
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
14332
14308
 
14333
14309
  const bool is_neox = mode & 2;
14334
- const bool is_glm = mode & 4;
14335
14310
 
14336
14311
  const float * freq_factors = NULL;
14337
- if (is_neox) {
14338
- if (src2 != NULL) {
14339
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
14340
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14341
- freq_factors = (const float *) src2->data;
14342
- }
14343
- } else {
14344
- GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14312
+ if (src2 != NULL) {
14313
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14314
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14315
+ freq_factors = (const float *) src2->data;
14345
14316
  }
14346
14317
 
14347
14318
  // backward process uses inverse rotation by cos and sin.
@@ -14356,101 +14327,50 @@ static void ggml_compute_forward_rope_f32(
14356
14327
  const int64_t p = pos[i2];
14357
14328
 
14358
14329
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
14359
- if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
14360
- ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
14361
- }
14330
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
14362
14331
 
14363
14332
  for (int64_t i1 = 0; i1 < ne1; i1++) {
14364
14333
  if (ir++ < ir0) continue;
14365
14334
  if (ir > ir1) break;
14366
14335
 
14367
- float theta_base = (float)p;
14368
-
14369
- if (is_glm) {
14370
- theta_base = MIN(p, n_ctx - 2);
14371
- float block_theta = MAX(p - (n_ctx - 2), 0);
14372
- for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
14373
- const float cos_theta = cosf(theta_base);
14374
- const float sin_theta = sinf(theta_base) * sin_sign;
14375
- const float cos_block_theta = cosf(block_theta);
14376
- const float sin_block_theta = sinf(block_theta) * sin_sign;
14377
-
14378
- theta_base *= theta_scale;
14379
- block_theta *= theta_scale;
14380
-
14381
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14382
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14383
-
14384
- const float x0 = src[0];
14385
- const float x1 = src[n_dims/2];
14386
- const float x2 = src[n_dims];
14387
- const float x3 = src[n_dims/2*3];
14388
-
14389
- dst_data[0] = x0*cos_theta - x1*sin_theta;
14390
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
14391
- dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
14392
- dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
14393
- }
14394
- } else if (!is_neox) {
14395
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
14336
+ if (!is_neox) {
14337
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
14396
14338
  const float cos_theta = cache[i0 + 0];
14397
14339
  const float sin_theta = cache[i0 + 1];
14398
14340
 
14399
- // zeta scaling for xPos only:
14400
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
14401
- if (xpos_down) zeta = 1.0f / zeta;
14402
-
14403
14341
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14404
14342
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14405
14343
 
14406
14344
  const float x0 = src[0];
14407
14345
  const float x1 = src[1];
14408
14346
 
14409
- dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
14410
- dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
14347
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
14348
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
14411
14349
  }
14412
14350
  } else {
14413
- // TODO: this might be wrong for ne0 != n_dims - need double check
14414
- // it seems we have to rope just the first n_dims elements and do nothing with the rest
14415
- // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14416
- theta_base *= freq_scale;
14417
- for (int64_t ic = 0; ic < ne0; ic += 2) {
14418
- if (ic < n_dims) {
14419
- const int64_t ib = 0;
14420
-
14421
- // simplified from `(ib * n_dims + ic) * inv_ndims`
14422
- float cur_rot = inv_ndims * ic - ib;
14423
- float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14351
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
14352
+ const int64_t ic = i0/2;
14424
14353
 
14425
- float cos_theta, sin_theta;
14426
- rope_yarn(
14427
- theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14428
- &cos_theta, &sin_theta
14429
- );
14430
- sin_theta *= sin_sign;
14431
-
14432
- theta_base *= theta_scale;
14433
-
14434
- const int64_t i0 = ib*n_dims + ic/2;
14354
+ const float cos_theta = cache[i0 + 0];
14355
+ const float sin_theta = cache[i0 + 1];
14435
14356
 
14436
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14437
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14357
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
14358
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
14438
14359
 
14439
- const float x0 = src[0];
14440
- const float x1 = src[n_dims/2];
14360
+ const float x0 = src[0];
14361
+ const float x1 = src[n_dims/2];
14441
14362
 
14442
- dst_data[0] = x0*cos_theta - x1*sin_theta;
14443
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
14444
- } else {
14445
- const int64_t i0 = ic;
14363
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
14364
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
14365
+ }
14366
+ }
14446
14367
 
14447
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14448
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14368
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
14369
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14370
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14449
14371
 
14450
- dst_data[0] = src[0];
14451
- dst_data[1] = src[1];
14452
- }
14453
- }
14372
+ dst_data[0] = src[0];
14373
+ dst_data[1] = src[1];
14454
14374
  }
14455
14375
  }
14456
14376
  }
@@ -14476,8 +14396,8 @@ static void ggml_compute_forward_rope_f16(
14476
14396
  //const int n_past = ((int32_t *) dst->op_params)[0];
14477
14397
  const int n_dims = ((int32_t *) dst->op_params)[1];
14478
14398
  const int mode = ((int32_t *) dst->op_params)[2];
14479
- const int n_ctx = ((int32_t *) dst->op_params)[3];
14480
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
14399
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
14400
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
14481
14401
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
14482
14402
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
14483
14403
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -14511,22 +14431,17 @@ static void ggml_compute_forward_rope_f16(
14511
14431
  int ir = 0;
14512
14432
 
14513
14433
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
14514
- const float inv_ndims = -1.f/n_dims;
14434
+
14515
14435
  float corr_dims[2];
14516
- ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
14436
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
14517
14437
 
14518
14438
  const bool is_neox = mode & 2;
14519
- const bool is_glm = mode & 4;
14520
14439
 
14521
14440
  const float * freq_factors = NULL;
14522
- if (is_neox) {
14523
- if (src2 != NULL) {
14524
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
14525
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14526
- freq_factors = (const float *) src2->data;
14527
- }
14528
- } else {
14529
- GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14441
+ if (src2 != NULL) {
14442
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14443
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14444
+ freq_factors = (const float *) src2->data;
14530
14445
  }
14531
14446
 
14532
14447
  // backward process uses inverse rotation by cos and sin.
@@ -14541,43 +14456,14 @@ static void ggml_compute_forward_rope_f16(
14541
14456
  const int64_t p = pos[i2];
14542
14457
 
14543
14458
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
14544
- if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
14545
- ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
14546
- }
14459
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
14547
         for (int64_t i1 = 0; i1 < ne1; i1++) {
             if (ir++ < ir0) continue;
             if (ir   > ir1) break;

-            float theta_base = (float)p;
-
-            if (is_glm) {
-                theta_base = MIN(p, n_ctx - 2);
-                float block_theta = MAX(p - (n_ctx - 2), 0);
-                for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                    const float cos_theta = cosf(theta_base);
-                    const float sin_theta = sinf(theta_base) * sin_sign;
-                    const float cos_block_theta = cosf(block_theta);
-                    const float sin_block_theta = sinf(block_theta) * sin_sign;
-
-                    theta_base  *= theta_scale;
-                    block_theta *= theta_scale;
-
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                    ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                    const float x0 = GGML_FP16_TO_FP32(src[0]);
-                    const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
-                    const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
-                    const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
-
-                    dst_data[0]          = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                    dst_data[n_dims/2]   = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    dst_data[n_dims]     = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
-                    dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
-                }
-            } else if (!is_neox) {
-                for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+            if (!is_neox) {
+                for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                     const float cos_theta = cache[i0 + 0];
                     const float sin_theta = cache[i0 + 1];

@@ -14591,47 +14477,29 @@ static void ggml_compute_forward_rope_f16(
                     dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             } else {
-                // TODO: this might be wrong for ne0 != n_dims - need double check
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
-                for (int64_t ic = 0; ic < ne0; ic += 2) {
-                    if (ic < n_dims) {
-                        const int64_t ib = 0;
-
-                        // simplified from `(ib * n_dims + ic) * inv_ndims`
-                        float cur_rot = inv_ndims * ic - ib;
-                        float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
-
-                        float cos_theta, sin_theta;
-                        rope_yarn(
-                            theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
-                            &cos_theta, &sin_theta
-                        );
-                        sin_theta *= sin_sign;
-
-                        theta_base *= theta_scale;
+                for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                    const int64_t ic = i0/2;

-                        const int64_t i0 = ib*n_dims + ic/2;
+                    const float cos_theta = cache[i0 + 0];
+                    const float sin_theta = cache[i0 + 1];

-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                    ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

-                        const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                    const float x0 = GGML_FP16_TO_FP32(src[0]);
+                    const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);

-                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    } else {
-                        const int64_t i0 = ic;
+                    dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                    dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                }
+            }

-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+            for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

-                        dst_data[0] = src[0];
-                        dst_data[1] = src[1];
-                    }
-                }
+                dst_data[0] = src[0];
+                dst_data[1] = src[1];
             }
         }
     }
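For context on what this refactored kernel computes: each pair of values is rotated by an angle taken from the precomputed cos/sin cache, and dimensions past n_dims are now copied through unchanged in a separate tail loop. A minimal standalone sketch of that rotation in the non-neox ("norm") layout, ignoring the YaRN/frequency-factor corrections that are baked into the cache; rope_pair_f32 and the flat in/out arrays are hypothetical names used only for illustration, not part of ggml's API:

#include <math.h>
#include <stddef.h>

// Hypothetical helper: rotate pairs (x[i0], x[i0+1]) by theta_i0 = pos * freq_base^(-i0/n_dims),
// mirroring the non-neox RoPE layout above; dims >= n_dims pass through untouched.
static void rope_pair_f32(const float * src, float * dst, size_t ne0, size_t n_dims,
                          float pos, float freq_base) {
    for (size_t i0 = 0; i0 < n_dims; i0 += 2) {
        const float theta = pos * powf(freq_base, -(float) i0 / (float) n_dims);
        const float c = cosf(theta);
        const float s = sinf(theta);
        dst[i0 + 0] = src[i0 + 0]*c - src[i0 + 1]*s;
        dst[i0 + 1] = src[i0 + 0]*s + src[i0 + 1]*c;
    }
    for (size_t i0 = n_dims; i0 < ne0; ++i0) {
        dst[i0] = src[i0]; // untouched tail, as in the new loop above
    }
}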
@@ -18333,9 +18201,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             //const int n_past = ((int32_t *) tensor->op_params)[0];
             const int n_dims = ((int32_t *) tensor->op_params)[1];
             const int mode   = ((int32_t *) tensor->op_params)[2];
-            const int n_ctx      = ((int32_t *) tensor->op_params)[3];
-            const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
-            float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+            //const int n_ctx = ((int32_t *) tensor->op_params)[3];
+            const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
+            float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;

             memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
             memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
@@ -18343,8 +18211,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
             memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
             memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
-            memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
-            memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));

             src0->grad = ggml_add_or_set(ctx,
                     src0->grad,
@@ -18354,16 +18220,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         src2,
                         n_dims,
                         mode,
-                        n_ctx,
-                        n_orig_ctx,
+                        n_ctx_orig,
                         freq_base,
                         freq_scale,
                         ext_factor,
                         attn_factor,
                         beta_fast,
-                        beta_slow,
-                        xpos_base,
-                        xpos_down),
+                        beta_slow),
                     zero_table);
             }
         } break;
@@ -18373,9 +18236,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             //const int n_past = ((int32_t *) tensor->op_params)[0];
             const int n_dims = ((int32_t *) tensor->op_params)[1];
             const int mode   = ((int32_t *) tensor->op_params)[2];
-            const int n_ctx      = ((int32_t *) tensor->op_params)[3];
-            const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
-            float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
+            //const int n_ctx = ((int32_t *) tensor->op_params)[3];
+            const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
+            float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;

             memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
             memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
@@ -18383,8 +18246,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
             memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
             memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
-            memcpy(&xpos_base,   (int32_t *) tensor->op_params + 11, sizeof(float));
-            memcpy(&xpos_down,   (int32_t *) tensor->op_params + 12, sizeof(bool));

             src0->grad = ggml_add_or_set(ctx,
                     src0->grad,
@@ -18394,16 +18255,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         src2,
                         n_dims,
                         mode,
-                        n_ctx,
-                        n_orig_ctx,
+                        n_ctx_orig,
                         freq_base,
                         freq_scale,
                         ext_factor,
                         attn_factor,
                         beta_fast,
                         beta_slow,
-                        xpos_base,
-                        xpos_down,
                         false),
                     zero_table);
             }
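These backward-pass hunks read the rope parameters from the packed op_params array: the first few 4-byte slots hold int32 values and the following slots hold floats, which are pulled out with memcpy to avoid aliasing issues. A rough sketch of that layout under stated assumptions: read_rope_params and the rope_params struct are hypothetical, the slot-7 read (ext_factor) falls outside the hunks shown and is assumed from context, and only the offsets that appear above are taken from the source.

#include <stdint.h>
#include <string.h>

// Hypothetical view of the layout implied by the reads above:
// slots 1, 2, 4 hold int32 values (n_dims, mode, n_ctx_orig; slot 3 is the now-unused n_ctx),
// slots 5..10 hold floats stored in the same 4-byte slots.
struct rope_params {
    int32_t n_dims, mode, n_ctx_orig;
    float   freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
};

static struct rope_params read_rope_params(const int32_t * op_params) {
    struct rope_params p;
    p.n_dims     = op_params[1];
    p.mode       = op_params[2];
    p.n_ctx_orig = op_params[4];
    memcpy(&p.freq_base,   op_params +  5, sizeof(float));
    memcpy(&p.freq_scale,  op_params +  6, sizeof(float));
    memcpy(&p.ext_factor,  op_params +  7, sizeof(float)); // assumed slot, not shown in the hunks
    memcpy(&p.attn_factor, op_params +  8, sizeof(float));
    memcpy(&p.beta_fast,   op_params +  9, sizeof(float));
    memcpy(&p.beta_slow,   op_params + 10, sizeof(float));
    return p;
}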
@@ -19510,11 +19368,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
             {
                 const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;

-#if defined(GGML_USE_CLBLAST)
-                if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
-                    cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
-                } else
-#endif
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                 if (ggml_compute_forward_mul_mat_use_blas(node)) {
                     if (node->src[0]->type != GGML_TYPE_F32) {
@@ -19644,6 +19497,59 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
     return cplan;
 }

+static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) {
+    enum ggml_status compute_status = GGML_STATUS_SUCCESS;
+
+#ifdef GGML_USE_OPENMP
+    if (n_threads > 1) {
+        #pragma omp parallel num_threads(n_threads)
+        {
+            #pragma omp single
+            {
+                // update the number of threads from the actual number of threads that we got from OpenMP
+                n_threads = omp_get_num_threads();
+                workers[0].shared->n_threads = n_threads;
+                workers[0].shared->n_active  = n_threads;
+            }
+            ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
+        }
+    } else {
+        ggml_graph_compute_thread(&workers[0]);
+    }
+#else
+    // create thread pool
+    if (n_threads > 1) {
+        for (int j = 1; j < n_threads; ++j) {
+            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+            GGML_ASSERT(rc == 0);
+            UNUSED(rc);
+        }
+    }
+
+    // this is a work thread too
+    ggml_graph_compute_thread(&workers[0]);
+
+    // join or kill thread pool
+    if (n_threads > 1) {
+        for (int j = 1; j < n_threads; j++) {
+            const int rc = ggml_thread_join(workers[j].thrd, NULL);
+            GGML_ASSERT(rc == 0);
+            UNUSED(rc);
+        }
+    }
+#endif
+    // don't leave affinity set on the main thread
+    clear_numa_thread_affinity();
+
+    for (int j = 0; j < n_threads; j++) {
+        if (workers[j].ec != GGML_STATUS_SUCCESS) {
+            compute_status = workers[j].ec;
+            break;
+        }
+    }
+    return compute_status;
+}
+
 enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
     {
         GGML_ASSERT(cplan);
@@ -19654,7 +19560,11 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         }
     }

-    const int n_threads = cplan->n_threads;
+    int n_threads = cplan->n_threads;
+
+#if defined(GGML_USE_OPENMP)
+    n_threads = MIN(n_threads, omp_get_max_threads());
+#endif

     struct ggml_compute_state_shared state_shared = {
         /*.cgraph        =*/ cgraph,
@@ -19670,47 +19580,20 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         /*.current_chunk; =*/ 0,
     };
     struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
-
-    // create thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; ++j) {
-            workers[j] = (struct ggml_compute_state) {
-                .thrd   = 0,
-                .ith    = j,
-                .shared = &state_shared,
-                .ec     = GGML_STATUS_SUCCESS,
-            };
-
-            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
-        }
-    }
-
-    workers[0].ith = 0;
-    workers[0].shared = &state_shared;
-    workers[0].ec = GGML_STATUS_SUCCESS;
-
     const int64_t perf_start_cycles  = ggml_perf_cycles();
     const int64_t perf_start_time_us = ggml_perf_time_us();

-    // this is a work thread too
-    ggml_graph_compute_thread(&workers[0]);
-    enum ggml_status compute_status = workers[0].ec;
-
-    // don't leave affinity set on the main thread
-    clear_numa_thread_affinity();
-
-    // join or kill thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; j++) {
-            const int rc = ggml_thread_join(workers[j].thrd, NULL);
-            GGML_ASSERT(rc == 0);
-            if (workers[j].ec != GGML_STATUS_SUCCESS)
-                compute_status = workers[j].ec;
-        }
+    for (int j = 0; j < n_threads; ++j) {
+        workers[j] = (struct ggml_compute_state) {
+            .thrd   = 0,
+            .ith    = j,
+            .shared = &state_shared,
+            .ec     = GGML_STATUS_SUCCESS,
+        };
     }

+    enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads);
+
     // performance stats (graph)
     {
         int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles;
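From the caller's side nothing changes: graphs are still planned and executed through ggml_graph_plan() and ggml_graph_compute(), and the new ggml_graph_compute_parallel() only changes how the workers are dispatched (OpenMP when built with GGML_USE_OPENMP, the pthread pool otherwise). A minimal caller-side sketch, assuming a graph gf built elsewhere (e.g. via ggml_build_forward_expand) and using only the signatures shown in this diff plus the work_size/work_data fields of ggml_cplan from ggml.h; error handling is elided:

#include <stdlib.h>
#include "ggml.h"

// Sketch: run an already-built graph with the new threading path.
static enum ggml_status run_graph(struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);

    uint8_t * work = NULL;
    if (cplan.work_size > 0) {
        work = malloc(cplan.work_size);  // the caller owns the scratch buffer the plan asks for
        cplan.work_data = work;
    }

    // With GGML_USE_OPENMP this fans out via the OpenMP parallel region above;
    // otherwise it falls back to the ggml_thread_create/ggml_thread_join pool.
    enum ggml_status status = ggml_graph_compute(gf, &cplan);

    free(work);
    return status;
}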
@@ -22742,6 +22625,16 @@ int ggml_cpu_has_neon(void) {
 #endif
 }

+int ggml_cpu_has_sve(void) {
+#if defined(__ARM_FEATURE_SVE)
+    // TODO: Currently, SVE 256 bit is only supported.
+    GGML_ASSERT(svcntb() == QK8_0);
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_arm_fma(void) {
 #if defined(__ARM_FEATURE_FMA)
     return 1;
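The new ggml_cpu_has_sve() probe sits alongside the existing ARM feature accessors and returns 1 only when the library was built with __ARM_FEATURE_SVE (and, per the TODO above, currently asserts a 256-bit vector length). A tiny illustrative probe, assuming these accessors are declared in the ggml.h that ships with this gem's vendored llama.cpp:

#include <stdio.h>
#include "ggml.h"

// Print the ARM-related CPU feature flags, including the new SVE one.
int main(void) {
    printf("NEON: %d  SVE: %d  ARM_FMA: %d\n",
           ggml_cpu_has_neon(), ggml_cpu_has_sve(), ggml_cpu_has_arm_fma());
    return 0;
}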
@@ -22783,7 +22676,7 @@ int ggml_cpu_has_wasm_simd(void) {
 }

 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_SYCL)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL)
     return 1;
 #else
     return 0;
@@ -22798,14 +22691,6 @@ int ggml_cpu_has_cuda(void) {
 #endif
 }

-int ggml_cpu_has_clblast(void) {
-#if defined(GGML_USE_CLBLAST)
-    return 1;
-#else
-    return 0;
-#endif
-}
-
 int ggml_cpu_has_vulkan(void) {
 #if defined(GGML_USE_VULKAN)
     return 1;
@@ -22830,9 +22715,16 @@ int ggml_cpu_has_sycl(void) {
 #endif
 }

+int ggml_cpu_has_rpc(void) {
+#if defined(GGML_USE_RPC)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_gpublas(void) {
-    return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
-           ggml_cpu_has_sycl();
+    return ggml_cpu_has_cuda() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() || ggml_cpu_has_sycl();
 }

 int ggml_cpu_has_sse3(void) {