llama_cpp 0.15.3 → 0.16.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +27 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +15 -1
  7. data/vendor/tmp/llama.cpp/Makefile +66 -36
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +35 -16
  131. data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
  132. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -7
  133. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  134. data/vendor/tmp/llama.cpp/ggml-metal.m +99 -35
  135. data/vendor/tmp/llama.cpp/ggml-metal.metal +146 -80
  136. data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
  137. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
  138. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +345 -227
  139. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  140. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +458 -329
  141. data/vendor/tmp/llama.cpp/ggml.c +301 -409
  142. data/vendor/tmp/llama.cpp/ggml.h +19 -23
  143. data/vendor/tmp/llama.cpp/llama.cpp +855 -651
  144. data/vendor/tmp/llama.cpp/llama.h +28 -48
  145. metadata +121 -6
  146. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  147. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  148. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  149. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
@@ -5,6 +5,7 @@
5
5
  #include "ggml-quants.h"
6
6
  #include "ggml.h"
7
7
 
8
+
8
9
  #if defined(_MSC_VER) || defined(__MINGW32__)
9
10
  #include <malloc.h> // using malloc.h with MSC/MINGW
10
11
  #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -28,6 +29,10 @@
28
29
  #include <syscall.h>
29
30
  #endif
30
31
 
32
+ #ifdef GGML_USE_OPENMP
33
+ #include <omp.h>
34
+ #endif
35
+
31
36
  #ifdef GGML_USE_METAL
32
37
  #include <unistd.h>
33
38
  #endif
@@ -60,6 +65,9 @@
60
65
 
61
66
  typedef volatile LONG atomic_int;
62
67
  typedef atomic_int atomic_bool;
68
+ typedef atomic_int atomic_flag;
69
+
70
+ #define ATOMIC_FLAG_INIT 0
63
71
 
64
72
  static void atomic_store(atomic_int * ptr, LONG val) {
65
73
  InterlockedExchange(ptr, val);
@@ -73,6 +81,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
73
81
  static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
74
82
  return atomic_fetch_add(ptr, -(dec));
75
83
  }
84
+ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
85
+ return InterlockedExchange(ptr, 1);
86
+ }
87
+ static void atomic_flag_clear(atomic_flag * ptr) {
88
+ InterlockedExchange(ptr, 0);
89
+ }
76
90
 
77
91
  typedef HANDLE pthread_t;
78
92
 
@@ -283,17 +297,12 @@ inline static void * ggml_calloc(size_t num, size_t size) {
283
297
 
284
298
  #if defined(GGML_USE_ACCELERATE)
285
299
  #include <Accelerate/Accelerate.h>
286
- #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
287
- #include "ggml-opencl.h"
288
- #endif
289
300
  #elif defined(GGML_USE_OPENBLAS)
290
301
  #if defined(GGML_BLAS_USE_MKL)
291
302
  #include <mkl.h>
292
303
  #else
293
304
  #include <cblas.h>
294
305
  #endif
295
- #elif defined(GGML_USE_CLBLAST)
296
- #include "ggml-opencl.h"
297
306
  #endif
298
307
 
299
308
  // floating point type used to accumulate sums
@@ -1567,11 +1576,11 @@ do { \
1567
1576
 
1568
1577
  // F16 arithmetic is not supported by AVX, so we use F32 instead
1569
1578
 
1570
- #define GGML_F32Cx8 __m256
1579
+ #define GGML_F32Cx8 __m256
1571
1580
  #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
1572
1581
  #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
1573
1582
 
1574
- static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
1583
+ static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
1575
1584
  float tmp[8];
1576
1585
 
1577
1586
  for (int i = 0; i < 8; i++) {
@@ -1580,13 +1589,14 @@ static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
1580
1589
 
1581
1590
  return (__m256)__lasx_xvld(tmp, 0);
1582
1591
  }
1583
- static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1592
+ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
1584
1593
  float arr[8];
1585
1594
 
1586
1595
  __lasx_xvst(y, arr, 0);
1587
1596
 
1588
- for (int i = 0; i < 8; i++)
1597
+ for (int i = 0; i < 8; i++) {
1589
1598
  x[i] = GGML_FP32_TO_FP16(arr[i]);
1599
+ }
1590
1600
  }
1591
1601
  #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
1592
1602
  #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
@@ -1662,7 +1672,7 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1662
1672
  #define GGML_F16_STEP 32
1663
1673
  #define GGML_F16_EPR 4
1664
1674
 
1665
- static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
1675
+ static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
1666
1676
  float tmp[4];
1667
1677
 
1668
1678
  tmp[0] = GGML_FP16_TO_FP32(x[0]);
@@ -1673,7 +1683,7 @@ static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
1673
1683
  return __lsx_vld(tmp, 0);
1674
1684
  }
1675
1685
 
1676
- static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
1686
+ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
1677
1687
  float arr[4];
1678
1688
 
1679
1689
  __lsx_vst(y, arr, 0);
@@ -1746,7 +1756,7 @@ struct ggml_compute_state_shared {
1746
1756
  int64_t perf_node_start_cycles;
1747
1757
  int64_t perf_node_start_time_us;
1748
1758
 
1749
- const int n_threads;
1759
+ int n_threads;
1750
1760
 
1751
1761
  // synchronization primitives
1752
1762
  atomic_int n_active; // num active threads
@@ -2257,6 +2267,11 @@ inline static float ggml_silu_f32(float x) {
2257
2267
  return x/(1.0f + expf(-x));
2258
2268
  }
2259
2269
 
2270
+ #if __FINITE_MATH_ONLY__
2271
+ #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
2272
+ #error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
2273
+ #endif
2274
+
2260
2275
  #if defined(__ARM_NEON) && defined(__aarch64__)
2261
2276
 
2262
2277
  // adapted from arm limited optimized routine
@@ -2306,32 +2321,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
2306
2321
  const __m512 r = _mm512_set1_ps(0x1.8p23f);
2307
2322
  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
2308
2323
  const __m512 n = _mm512_sub_ps(z, r);
2309
- const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2310
- _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
2311
- const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
2312
- const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
2313
- const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
2314
- const __m512 u = _mm512_mul_ps(b, b);
2315
- const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2316
- _mm512_set1_ps(0x1.573e2ep-5f)), u,
2317
- _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2318
- _mm512_set1_ps(0x1.fffdb6p-2f))),
2319
- u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
2320
- if (_mm512_kortestz(c, c))
2321
- return _mm512_fmadd_ps(j, k, k);
2322
- const __m512i g = _mm512_and_si512(
2323
- _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
2324
- _mm512_set1_epi32(0x82000000u));
2325
- const __m512 s1 =
2326
- _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
2327
- const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
2324
+ const __m512 b =
2325
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2326
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
2328
2327
  const __mmask16 d =
2329
2328
  _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
2330
- return _mm512_mask_blend_ps(
2331
- d, _mm512_mask_blend_ps(
2332
- c, _mm512_fmadd_ps(k, j, k),
2333
- _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
2334
- _mm512_mul_ps(s1, s1));
2329
+ const __m512 u = _mm512_mul_ps(b, b);
2330
+ const __m512 j = _mm512_fmadd_ps(
2331
+ _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2332
+ _mm512_set1_ps(0x1.573e2ep-5f)),
2333
+ u,
2334
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2335
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
2336
+ u,
2337
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
2338
+ const __m512 res = _mm512_scalef_ps(j, n);
2339
+ if (_mm512_kortestz(d, d))
2340
+ return res;
2341
+ const __m512 zero = _mm512_setzero_ps();
2342
+ const __m512 alt = _mm512_mask_blend_ps(
2343
+ _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
2344
+ return _mm512_mask_blend_ps(d, res, alt);
2335
2345
  }
2336
2346
 
2337
2347
  // computes silu x/(1+exp(-x)) in single precision vector
@@ -2883,24 +2893,20 @@ struct ggml_state {
2883
2893
 
2884
2894
  // global state
2885
2895
  static struct ggml_state g_state;
2886
- static atomic_int g_state_barrier = 0;
2896
+ static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
2887
2897
 
2888
2898
  // barrier via spin lock
2889
2899
  inline static void ggml_critical_section_start(void) {
2890
- int processing = atomic_fetch_add(&g_state_barrier, 1);
2891
-
2892
- while (processing > 0) {
2893
- // wait for other threads to finish
2894
- atomic_fetch_sub(&g_state_barrier, 1);
2895
- sched_yield(); // TODO: reconsider this
2896
- processing = atomic_fetch_add(&g_state_barrier, 1);
2900
+ while (atomic_flag_test_and_set(&g_state_critical)) {
2901
+ // spin
2902
+ sched_yield();
2897
2903
  }
2898
2904
  }
2899
2905
 
2900
2906
  // TODO: make this somehow automatically executed
2901
2907
  // some sort of "sentry" mechanism
2902
2908
  inline static void ggml_critical_section_end(void) {
2903
- atomic_fetch_sub(&g_state_barrier, 1);
2909
+ atomic_flag_clear(&g_state_critical);
2904
2910
  }
2905
2911
 
2906
2912
  #if defined(__gnu_linux__)
@@ -3216,7 +3222,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3216
3222
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3217
3223
  }
3218
3224
 
3219
- static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
3225
+ GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
3226
+ return ggml_is_contiguous(tensor);
3227
+ }
3228
+
3229
+ GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
3220
3230
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3221
3231
 
3222
3232
  return
@@ -3225,6 +3235,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
3225
3235
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3226
3236
  }
3227
3237
 
3238
+ GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
3239
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3240
+
3241
+ return
3242
+ tensor->nb[0] == ggml_type_size(tensor->type) &&
3243
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3244
+ }
3245
+
3228
3246
  GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
3229
3247
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3230
3248
 
@@ -3357,10 +3375,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3357
3375
  GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
3358
3376
  }
3359
3377
 
3360
- #if defined(GGML_USE_CLBLAST)
3361
- ggml_cl_init();
3362
- #endif
3363
-
3364
3378
  ggml_setup_op_has_task_pass();
3365
3379
 
3366
3380
  is_first_call = false;
@@ -4882,10 +4896,21 @@ struct ggml_tensor * ggml_repeat_back(
4882
4896
  // ggml_concat
4883
4897
 
4884
4898
  struct ggml_tensor * ggml_concat(
4885
- struct ggml_context* ctx,
4886
- struct ggml_tensor* a,
4887
- struct ggml_tensor* b) {
4888
- GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
4899
+ struct ggml_context * ctx,
4900
+ struct ggml_tensor * a,
4901
+ struct ggml_tensor * b,
4902
+ int dim) {
4903
+ GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
4904
+
4905
+ int64_t ne[GGML_MAX_DIMS];
4906
+ for (int d = 0; d < GGML_MAX_DIMS; ++d) {
4907
+ if (d == dim) {
4908
+ ne[d] = a->ne[d] + b->ne[d];
4909
+ continue;
4910
+ }
4911
+ GGML_ASSERT(a->ne[d] == b->ne[d]);
4912
+ ne[d] = a->ne[d];
4913
+ }
4889
4914
 
4890
4915
  bool is_node = false;
4891
4916
 
@@ -4893,7 +4918,9 @@ struct ggml_tensor * ggml_concat(
4893
4918
  is_node = true;
4894
4919
  }
4895
4920
 
4896
- struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
4921
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
4922
+
4923
+ ggml_set_op_params_i32(result, 0, dim);
4897
4924
 
4898
4925
  result->op = GGML_OP_CONCAT;
4899
4926
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5013,6 +5040,7 @@ struct ggml_tensor * ggml_leaky_relu(
5013
5040
  }
5014
5041
 
5015
5042
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5043
+
5016
5044
  ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
5017
5045
 
5018
5046
  result->op = GGML_OP_LEAKY_RELU;
@@ -6222,16 +6250,13 @@ static struct ggml_tensor * ggml_rope_impl(
6222
6250
  struct ggml_tensor * c,
6223
6251
  int n_dims,
6224
6252
  int mode,
6225
- int n_ctx,
6226
- int n_orig_ctx,
6253
+ int n_ctx_orig,
6227
6254
  float freq_base,
6228
6255
  float freq_scale,
6229
6256
  float ext_factor,
6230
6257
  float attn_factor,
6231
6258
  float beta_fast,
6232
6259
  float beta_slow,
6233
- float xpos_base,
6234
- bool xpos_down,
6235
6260
  bool inplace) {
6236
6261
  GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
6237
6262
 
@@ -6252,15 +6277,13 @@ static struct ggml_tensor * ggml_rope_impl(
6252
6277
 
6253
6278
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6254
6279
 
6255
- int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
6280
+ int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
6256
6281
  memcpy(params + 5, &freq_base, sizeof(float));
6257
6282
  memcpy(params + 6, &freq_scale, sizeof(float));
6258
6283
  memcpy(params + 7, &ext_factor, sizeof(float));
6259
6284
  memcpy(params + 8, &attn_factor, sizeof(float));
6260
6285
  memcpy(params + 9, &beta_fast, sizeof(float));
6261
6286
  memcpy(params + 10, &beta_slow, sizeof(float));
6262
- memcpy(params + 11, &xpos_base, sizeof(float));
6263
- memcpy(params + 12, &xpos_down, sizeof(bool));
6264
6287
  ggml_set_op_params(result, params, sizeof(params));
6265
6288
 
6266
6289
  result->op = GGML_OP_ROPE;
@@ -6277,10 +6300,9 @@ struct ggml_tensor * ggml_rope(
6277
6300
  struct ggml_tensor * a,
6278
6301
  struct ggml_tensor * b,
6279
6302
  int n_dims,
6280
- int mode,
6281
- int n_ctx) {
6303
+ int mode) {
6282
6304
  return ggml_rope_impl(
6283
- ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6305
+ ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
6284
6306
  );
6285
6307
  }
6286
6308
 
@@ -6289,10 +6311,9 @@ struct ggml_tensor * ggml_rope_inplace(
6289
6311
  struct ggml_tensor * a,
6290
6312
  struct ggml_tensor * b,
6291
6313
  int n_dims,
6292
- int mode,
6293
- int n_ctx) {
6314
+ int mode) {
6294
6315
  return ggml_rope_impl(
6295
- ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
6316
+ ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
6296
6317
  );
6297
6318
  }
6298
6319
 
@@ -6303,8 +6324,7 @@ struct ggml_tensor * ggml_rope_ext(
6303
6324
  struct ggml_tensor * c,
6304
6325
  int n_dims,
6305
6326
  int mode,
6306
- int n_ctx,
6307
- int n_orig_ctx,
6327
+ int n_ctx_orig,
6308
6328
  float freq_base,
6309
6329
  float freq_scale,
6310
6330
  float ext_factor,
@@ -6312,8 +6332,8 @@ struct ggml_tensor * ggml_rope_ext(
6312
6332
  float beta_fast,
6313
6333
  float beta_slow) {
6314
6334
  return ggml_rope_impl(
6315
- ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6316
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6335
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
6336
+ ext_factor, attn_factor, beta_fast, beta_slow, false
6317
6337
  );
6318
6338
  }
6319
6339
 
@@ -6324,8 +6344,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
6324
6344
  struct ggml_tensor * c,
6325
6345
  int n_dims,
6326
6346
  int mode,
6327
- int n_ctx,
6328
- int n_orig_ctx,
6347
+ int n_ctx_orig,
6329
6348
  float freq_base,
6330
6349
  float freq_scale,
6331
6350
  float ext_factor,
@@ -6333,8 +6352,8 @@ struct ggml_tensor * ggml_rope_ext_inplace(
6333
6352
  float beta_fast,
6334
6353
  float beta_slow) {
6335
6354
  return ggml_rope_impl(
6336
- ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6337
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6355
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
6356
+ ext_factor, attn_factor, beta_fast, beta_slow, true
6338
6357
  );
6339
6358
  }
6340
6359
 
@@ -6344,8 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
6344
6363
  struct ggml_tensor * b,
6345
6364
  int n_dims,
6346
6365
  int mode,
6347
- int n_ctx,
6348
- int n_orig_ctx,
6366
+ int n_ctx_orig,
6349
6367
  float freq_base,
6350
6368
  float freq_scale,
6351
6369
  float ext_factor,
@@ -6353,8 +6371,8 @@ struct ggml_tensor * ggml_rope_custom(
6353
6371
  float beta_fast,
6354
6372
  float beta_slow) {
6355
6373
  return ggml_rope_impl(
6356
- ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6357
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6374
+ ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
6375
+ ext_factor, attn_factor, beta_fast, beta_slow, false
6358
6376
  );
6359
6377
  }
6360
6378
 
@@ -6364,8 +6382,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
6364
6382
  struct ggml_tensor * b,
6365
6383
  int n_dims,
6366
6384
  int mode,
6367
- int n_ctx,
6368
- int n_orig_ctx,
6385
+ int n_ctx_orig,
6369
6386
  float freq_base,
6370
6387
  float freq_scale,
6371
6388
  float ext_factor,
@@ -6373,8 +6390,8 @@ struct ggml_tensor * ggml_rope_custom_inplace(
6373
6390
  float beta_fast,
6374
6391
  float beta_slow) {
6375
6392
  return ggml_rope_impl(
6376
- ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6377
- ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6393
+ ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
6394
+ ext_factor, attn_factor, beta_fast, beta_slow, true
6378
6395
  );
6379
6396
  }
6380
6397
 
@@ -6387,16 +6404,13 @@ struct ggml_tensor * ggml_rope_back(
6387
6404
  struct ggml_tensor * c,
6388
6405
  int n_dims,
6389
6406
  int mode,
6390
- int n_ctx,
6391
- int n_orig_ctx,
6407
+ int n_ctx_orig,
6392
6408
  float freq_base,
6393
6409
  float freq_scale,
6394
6410
  float ext_factor,
6395
6411
  float attn_factor,
6396
6412
  float beta_fast,
6397
- float beta_slow,
6398
- float xpos_base,
6399
- bool xpos_down) {
6413
+ float beta_slow) {
6400
6414
  GGML_ASSERT(ggml_is_vector(b));
6401
6415
  GGML_ASSERT(b->type == GGML_TYPE_I32);
6402
6416
  GGML_ASSERT(a->ne[2] == b->ne[0]);
@@ -6412,15 +6426,13 @@ struct ggml_tensor * ggml_rope_back(
6412
6426
 
6413
6427
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6414
6428
 
6415
- int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
6429
+ int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
6416
6430
  memcpy(params + 5, &freq_base, sizeof(float));
6417
6431
  memcpy(params + 6, &freq_scale, sizeof(float));
6418
6432
  memcpy(params + 7, &ext_factor, sizeof(float));
6419
6433
  memcpy(params + 8, &attn_factor, sizeof(float));
6420
6434
  memcpy(params + 9, &beta_fast, sizeof(float));
6421
6435
  memcpy(params + 10, &beta_slow, sizeof(float));
6422
- memcpy(params + 11, &xpos_base, sizeof(float));
6423
- memcpy(params + 12, &xpos_down, sizeof(bool));
6424
6436
  ggml_set_op_params(result, params, sizeof(params));
6425
6437
 
6426
6438
  result->op = GGML_OP_ROPE_BACK;
@@ -9006,17 +9018,6 @@ static void ggml_compute_forward_add_f32(
9006
9018
  const int ith = params->ith;
9007
9019
  const int nth = params->nth;
9008
9020
 
9009
- #ifdef GGML_USE_CLBLAST
9010
- if (src1->backend == GGML_BACKEND_TYPE_GPU) {
9011
- // TODO: OpenCL kernel support full broadcast
9012
- GGML_ASSERT(ggml_can_repeat_rows(src1, src0));
9013
- if (ith == 0) {
9014
- ggml_cl_add(src0, src1, dst);
9015
- }
9016
- return;
9017
- }
9018
- #endif
9019
-
9020
9021
  const int nr = ggml_nrows(src0);
9021
9022
 
9022
9023
  GGML_TENSOR_BINARY_OP_LOCALS
@@ -10124,17 +10125,6 @@ static void ggml_compute_forward_mul_f32(
10124
10125
  const int ith = params->ith;
10125
10126
  const int nth = params->nth;
10126
10127
 
10127
- #if defined(GGML_USE_CLBLAST)
10128
- if (src1->backend == GGML_BACKEND_TYPE_GPU) {
10129
- // TODO: OpenCL kernel support full broadcast
10130
- GGML_ASSERT(ggml_can_repeat_rows(src1, src0));
10131
- if (ith == 0) {
10132
- ggml_cl_mul(src0, src1, dst);
10133
- }
10134
- return;
10135
- }
10136
- #endif
10137
-
10138
10128
  const int64_t nr = ggml_nrows(src0);
10139
10129
 
10140
10130
  GGML_TENSOR_BINARY_OP_LOCALS
@@ -10967,26 +10957,29 @@ static void ggml_compute_forward_concat_f32(
10967
10957
  GGML_ASSERT(nb00 == sizeof(float));
10968
10958
  GGML_ASSERT(nb10 == sizeof(float));
10969
10959
 
10960
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
10961
+
10962
+ GGML_ASSERT(dim >= 0 && dim < 4);
10963
+
10964
+ int64_t o[4] = {0, 0, 0, 0};
10965
+ o[dim] = src0->ne[dim];
10966
+
10967
+ const float * x;
10968
+
10969
+ // TODO: smarter multi-theading
10970
10970
  for (int i3 = 0; i3 < ne3; i3++) {
10971
10971
  for (int i2 = ith; i2 < ne2; i2 += nth) {
10972
- if (i2 < ne02) { // src0
10973
- for (int i1 = 0; i1 < ne1; i1++) {
10974
- for (int i0 = 0; i0 < ne0; i0++) {
10975
- const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
10976
-
10977
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10978
- *y = *x;
10979
- }
10980
- }
10981
- } // src1
10982
- else {
10983
- for (int i1 = 0; i1 < ne1; i1++) {
10984
- for (int i0 = 0; i0 < ne0; i0++) {
10985
- const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
10986
-
10987
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10988
- *y = *x;
10972
+ for (int i1 = 0; i1 < ne1; i1++) {
10973
+ for (int i0 = 0; i0 < ne0; i0++) {
10974
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
10975
+ x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
10976
+ } else {
10977
+ x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
10989
10978
  }
10979
+
10980
+ float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
10981
+
10982
+ *y = *x;
10990
10983
  }
10991
10984
  }
10992
10985
  }
@@ -10994,8 +10987,8 @@ static void ggml_compute_forward_concat_f32(
10994
10987
  }
10995
10988
 
10996
10989
  static void ggml_compute_forward_concat(
10997
- const struct ggml_compute_params* params,
10998
- struct ggml_tensor* dst) {
10990
+ const struct ggml_compute_params * params,
10991
+ struct ggml_tensor * dst) {
10999
10992
 
11000
10993
  const struct ggml_tensor * src0 = dst->src[0];
11001
10994
 
@@ -11388,8 +11381,8 @@ static void ggml_compute_forward_gelu_f32(
11388
11381
 
11389
11382
  const struct ggml_tensor * src0 = dst->src[0];
11390
11383
 
11391
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11392
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11384
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11385
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11393
11386
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11394
11387
 
11395
11388
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11451,8 +11444,8 @@ static void ggml_compute_forward_gelu_quick_f32(
11451
11444
 
11452
11445
  const struct ggml_tensor * src0 = dst->src[0];
11453
11446
 
11454
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11455
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11447
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11448
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11456
11449
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11457
11450
 
11458
11451
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11514,8 +11507,8 @@ static void ggml_compute_forward_silu_f32(
11514
11507
 
11515
11508
  const struct ggml_tensor * src0 = dst->src[0];
11516
11509
 
11517
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11518
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11510
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11511
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11519
11512
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11520
11513
 
11521
11514
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11626,9 +11619,9 @@ static void ggml_compute_forward_silu_back_f32(
11626
11619
  const struct ggml_tensor * src0 = dst->src[0];
11627
11620
  const struct ggml_tensor * grad = dst->src[1];
11628
11621
 
11629
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
11630
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11631
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11622
+ GGML_ASSERT(ggml_is_contiguous_1(grad));
11623
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11624
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11632
11625
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11633
11626
  GGML_ASSERT(ggml_are_same_shape(src0, grad));
11634
11627
 
@@ -12367,15 +12360,6 @@ static void ggml_compute_forward_mul_mat(
12367
12360
  // nb01 >= nb00 - src0 is not transposed
12368
12361
  // compute by src0 rows
12369
12362
 
12370
- #if defined(GGML_USE_CLBLAST)
12371
- if (ggml_cl_can_mul_mat(src0, src1, dst)) {
12372
- if (params->ith == 0 && params->type == GGML_TASK_TYPE_COMPUTE) {
12373
- ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
12374
- }
12375
- return;
12376
- }
12377
- #endif
12378
-
12379
12363
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
12380
12364
  if (ggml_compute_forward_mul_mat_use_blas(dst)) {
12381
12365
  const int64_t ne_plane = ne01*ne00;
@@ -12823,8 +12807,6 @@ static void ggml_compute_forward_out_prod_f32(
12823
12807
  // nb01 >= nb00 - src0 is not transposed
12824
12808
  // compute by src0 rows
12825
12809
 
12826
- // TODO: #if defined(GGML_USE_CLBLAST)
12827
-
12828
12810
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
12829
12811
  bool use_blas = ggml_is_matrix(src0) &&
12830
12812
  ggml_is_matrix(src1) &&
@@ -13022,7 +13004,7 @@ static void ggml_compute_forward_out_prod_q_f32(
13022
13004
  // nb01 >= nb00 - src0 is not transposed
13023
13005
  // compute by src0 rows
13024
13006
 
13025
- // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
13007
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
13026
13008
 
13027
13009
  if (params->type == GGML_TASK_TYPE_INIT) {
13028
13010
  if (ith != 0) {
@@ -14219,8 +14201,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
14219
14201
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
14220
14202
  static void rope_yarn(
14221
14203
  float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
14222
- float * cos_theta, float * sin_theta
14223
- ) {
14204
+ float * cos_theta, float * sin_theta) {
14224
14205
  // Get n-d rotational scaling corrected for extrapolation
14225
14206
  float theta_interp = freq_scale * theta_extrap;
14226
14207
  float theta = theta_interp;
@@ -14237,18 +14218,19 @@ static void rope_yarn(
14237
14218
 
14238
14219
  // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
14239
14220
  // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
14240
- static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
14241
- return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
14221
+ static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
14222
+ return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
14242
14223
  }
14243
14224
 
14244
14225
  static void ggml_rope_cache_init(
14245
- float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
14246
- float * cache, float sin_sign, float theta_scale
14247
- ) {
14226
+ float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
14227
+ float * cache, float sin_sign, float theta_scale) {
14228
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
14248
14229
  float theta = theta_base;
14249
14230
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
14231
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
14250
14232
  rope_yarn(
14251
- theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
14233
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
14252
14234
  );
14253
14235
  cache[i0 + 1] *= sin_sign;
14254
14236
 
@@ -14257,11 +14239,11 @@ static void ggml_rope_cache_init(
14257
14239
  }
14258
14240
 
14259
14241
  GGML_CALL void ggml_rope_yarn_corr_dims(
14260
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
14242
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
14261
14243
  ) {
14262
14244
  // start and end correction dims
14263
- float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base));
14264
- float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base));
14245
+ float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
14246
+ float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
14265
14247
  dims[0] = MAX(0, start);
14266
14248
  dims[1] = MIN(n_dims - 1, end);
14267
14249
  }
@@ -14281,15 +14263,11 @@ static void ggml_compute_forward_rope_f32(
14281
14263
 
14282
14264
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
14283
14265
 
14284
- // these two only relevant for xPos RoPE:
14285
- float xpos_base;
14286
- bool xpos_down;
14287
-
14288
14266
  //const int n_past = ((int32_t *) dst->op_params)[0];
14289
14267
  const int n_dims = ((int32_t *) dst->op_params)[1];
14290
14268
  const int mode = ((int32_t *) dst->op_params)[2];
14291
- const int n_ctx = ((int32_t *) dst->op_params)[3];
14292
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
14269
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
14270
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
14293
14271
 
14294
14272
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
14295
14273
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -14297,8 +14275,6 @@ static void ggml_compute_forward_rope_f32(
14297
14275
  memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
14298
14276
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
14299
14277
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
14300
- memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
14301
- memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
14302
14278
 
14303
14279
  GGML_TENSOR_UNARY_OP_LOCALS
14304
14280
 
@@ -14326,22 +14302,17 @@ static void ggml_compute_forward_rope_f32(
14326
14302
  int ir = 0;
14327
14303
 
14328
14304
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
14329
- const float inv_ndims = -1.f/n_dims;
14305
+
14330
14306
  float corr_dims[2];
14331
- ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
14307
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
14332
14308
 
14333
14309
  const bool is_neox = mode & 2;
14334
- const bool is_glm = mode & 4;
14335
14310
 
14336
14311
  const float * freq_factors = NULL;
14337
- if (is_neox) {
14338
- if (src2 != NULL) {
14339
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
14340
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14341
- freq_factors = (const float *) src2->data;
14342
- }
14343
- } else {
14344
- GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14312
+ if (src2 != NULL) {
14313
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14314
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14315
+ freq_factors = (const float *) src2->data;
14345
14316
  }
14346
14317
 
14347
14318
  // backward process uses inverse rotation by cos and sin.
@@ -14356,101 +14327,50 @@ static void ggml_compute_forward_rope_f32(
14356
14327
  const int64_t p = pos[i2];
14357
14328
 
14358
14329
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
14359
- if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
14360
- ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
14361
- }
14330
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
14362
14331
 
14363
14332
  for (int64_t i1 = 0; i1 < ne1; i1++) {
14364
14333
  if (ir++ < ir0) continue;
14365
14334
  if (ir > ir1) break;
14366
14335
 
14367
- float theta_base = (float)p;
14368
-
14369
- if (is_glm) {
14370
- theta_base = MIN(p, n_ctx - 2);
14371
- float block_theta = MAX(p - (n_ctx - 2), 0);
14372
- for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
14373
- const float cos_theta = cosf(theta_base);
14374
- const float sin_theta = sinf(theta_base) * sin_sign;
14375
- const float cos_block_theta = cosf(block_theta);
14376
- const float sin_block_theta = sinf(block_theta) * sin_sign;
14377
-
14378
- theta_base *= theta_scale;
14379
- block_theta *= theta_scale;
14380
-
14381
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14382
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14383
-
14384
- const float x0 = src[0];
14385
- const float x1 = src[n_dims/2];
14386
- const float x2 = src[n_dims];
14387
- const float x3 = src[n_dims/2*3];
14388
-
14389
- dst_data[0] = x0*cos_theta - x1*sin_theta;
14390
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
14391
- dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta;
14392
- dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
14393
- }
14394
- } else if (!is_neox) {
14395
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
14336
+ if (!is_neox) {
14337
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
14396
14338
  const float cos_theta = cache[i0 + 0];
14397
14339
  const float sin_theta = cache[i0 + 1];
14398
14340
 
14399
- // zeta scaling for xPos only:
14400
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
14401
- if (xpos_down) zeta = 1.0f / zeta;
14402
-
14403
14341
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14404
14342
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14405
14343
 
14406
14344
  const float x0 = src[0];
14407
14345
  const float x1 = src[1];
14408
14346
 
14409
- dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
14410
- dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
14347
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
14348
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
14411
14349
  }
14412
14350
  } else {
14413
- // TODO: this might be wrong for ne0 != n_dims - need double check
14414
- // it seems we have to rope just the first n_dims elements and do nothing with the rest
14415
- // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14416
- theta_base *= freq_scale;
14417
- for (int64_t ic = 0; ic < ne0; ic += 2) {
14418
- if (ic < n_dims) {
14419
- const int64_t ib = 0;
14420
-
14421
- // simplified from `(ib * n_dims + ic) * inv_ndims`
14422
- float cur_rot = inv_ndims * ic - ib;
14423
- float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14351
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
14352
+ const int64_t ic = i0/2;
14424
14353
 
14425
- float cos_theta, sin_theta;
14426
- rope_yarn(
14427
- theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14428
- &cos_theta, &sin_theta
14429
- );
14430
- sin_theta *= sin_sign;
14431
-
14432
- theta_base *= theta_scale;
14433
-
14434
- const int64_t i0 = ib*n_dims + ic/2;
14354
+ const float cos_theta = cache[i0 + 0];
14355
+ const float sin_theta = cache[i0 + 1];
14435
14356
 
14436
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14437
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14357
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
14358
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
14438
14359
 
14439
- const float x0 = src[0];
14440
- const float x1 = src[n_dims/2];
14360
+ const float x0 = src[0];
14361
+ const float x1 = src[n_dims/2];
14441
14362
 
14442
- dst_data[0] = x0*cos_theta - x1*sin_theta;
14443
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
14444
- } else {
14445
- const int64_t i0 = ic;
14363
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
14364
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
14365
+ }
14366
+ }
14446
14367
 
14447
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14448
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14368
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
14369
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14370
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14449
14371
 
14450
- dst_data[0] = src[0];
14451
- dst_data[1] = src[1];
14452
- }
14453
- }
14372
+ dst_data[0] = src[0];
14373
+ dst_data[1] = src[1];
14454
14374
  }
14455
14375
  }
14456
14376
  }
@@ -14476,8 +14396,8 @@ static void ggml_compute_forward_rope_f16(
14476
14396
  //const int n_past = ((int32_t *) dst->op_params)[0];
14477
14397
  const int n_dims = ((int32_t *) dst->op_params)[1];
14478
14398
  const int mode = ((int32_t *) dst->op_params)[2];
14479
- const int n_ctx = ((int32_t *) dst->op_params)[3];
14480
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
14399
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
14400
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
14481
14401
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
14482
14402
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
14483
14403
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -14511,22 +14431,17 @@ static void ggml_compute_forward_rope_f16(
14511
14431
  int ir = 0;
14512
14432
 
14513
14433
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
14514
- const float inv_ndims = -1.f/n_dims;
14434
+
14515
14435
  float corr_dims[2];
14516
- ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
14436
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
14517
14437
 
14518
14438
  const bool is_neox = mode & 2;
14519
- const bool is_glm = mode & 4;
14520
14439
 
14521
14440
  const float * freq_factors = NULL;
14522
- if (is_neox) {
14523
- if (src2 != NULL) {
14524
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
14525
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14526
- freq_factors = (const float *) src2->data;
14527
- }
14528
- } else {
14529
- GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14441
+ if (src2 != NULL) {
14442
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14443
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14444
+ freq_factors = (const float *) src2->data;
14530
14445
  }
14531
14446
 
14532
14447
  // backward process uses inverse rotation by cos and sin.
@@ -14541,43 +14456,14 @@ static void ggml_compute_forward_rope_f16(
14541
14456
  const int64_t p = pos[i2];
14542
14457
 
14543
14458
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
14544
- if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
14545
- ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
14546
- }
14459
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
14547
14460
 
14548
14461
  for (int64_t i1 = 0; i1 < ne1; i1++) {
14549
14462
  if (ir++ < ir0) continue;
14550
14463
  if (ir > ir1) break;
14551
14464
 
14552
- float theta_base = (float)p;
14553
-
14554
- if (is_glm) {
14555
- theta_base = MIN(p, n_ctx - 2);
14556
- float block_theta = MAX(p - (n_ctx - 2), 0);
14557
- for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
14558
- const float cos_theta = cosf(theta_base);
14559
- const float sin_theta = sinf(theta_base) * sin_sign;
14560
- const float cos_block_theta = cosf(block_theta);
14561
- const float sin_block_theta = sinf(block_theta) * sin_sign;
14562
-
14563
- theta_base *= theta_scale;
14564
- block_theta *= theta_scale;
14565
-
14566
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14567
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14568
-
14569
- const float x0 = GGML_FP16_TO_FP32(src[0]);
14570
- const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
14571
- const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
14572
- const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
14573
-
14574
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
14575
- dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
14576
- dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
14577
- dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
14578
- }
14579
- } else if (!is_neox) {
14580
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
14465
+ if (!is_neox) {
14466
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
14581
14467
  const float cos_theta = cache[i0 + 0];
14582
14468
  const float sin_theta = cache[i0 + 1];
14583
14469
 
@@ -14591,47 +14477,29 @@ static void ggml_compute_forward_rope_f16(
14591
14477
  dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
14592
14478
  }
14593
14479
  } else {
14594
- // TODO: this might be wrong for ne0 != n_dims - need double check
14595
- // it seems we have to rope just the first n_dims elements and do nothing with the rest
14596
- // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14597
- theta_base *= freq_scale;
14598
- for (int64_t ic = 0; ic < ne0; ic += 2) {
14599
- if (ic < n_dims) {
14600
- const int64_t ib = 0;
14601
-
14602
- // simplified from `(ib * n_dims + ic) * inv_ndims`
14603
- float cur_rot = inv_ndims * ic - ib;
14604
- float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14605
-
14606
- float cos_theta, sin_theta;
14607
- rope_yarn(
14608
- theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14609
- &cos_theta, &sin_theta
14610
- );
14611
- sin_theta *= sin_sign;
14612
-
14613
- theta_base *= theta_scale;
14480
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
14481
+ const int64_t ic = i0/2;
14614
14482
 
14615
- const int64_t i0 = ib*n_dims + ic/2;
14483
+ const float cos_theta = cache[i0 + 0];
14484
+ const float sin_theta = cache[i0 + 1];
14616
14485
 
14617
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14618
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14486
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
14487
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
14619
14488
 
14620
- const float x0 = GGML_FP16_TO_FP32(src[0]);
14621
- const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
14489
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
14490
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
14622
14491
 
14623
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
14624
- dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
14625
- } else {
14626
- const int64_t i0 = ic;
14492
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
14493
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
14494
+ }
14495
+ }
14627
14496
 
14628
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14629
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14497
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
14498
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14499
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14630
14500
 
14631
- dst_data[0] = src[0];
14632
- dst_data[1] = src[1];
14633
- }
14634
- }
14501
+ dst_data[0] = src[0];
14502
+ dst_data[1] = src[1];
14635
14503
  }
14636
14504
  }
14637
14505
  }
@@ -18333,9 +18201,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18333
18201
  //const int n_past = ((int32_t *) tensor->op_params)[0];
18334
18202
  const int n_dims = ((int32_t *) tensor->op_params)[1];
18335
18203
  const int mode = ((int32_t *) tensor->op_params)[2];
18336
- const int n_ctx = ((int32_t *) tensor->op_params)[3];
18337
- const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
18338
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
18204
+ //const int n_ctx = ((int32_t *) tensor->op_params)[3];
18205
+ const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
18206
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
18339
18207
 
18340
18208
  memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
18341
18209
  memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
@@ -18343,8 +18211,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18343
18211
  memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
18344
18212
  memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
18345
18213
  memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
18346
- memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
18347
- memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
18348
18214
 
18349
18215
  src0->grad = ggml_add_or_set(ctx,
18350
18216
  src0->grad,
@@ -18354,16 +18220,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18354
18220
  src2,
18355
18221
  n_dims,
18356
18222
  mode,
18357
- n_ctx,
18358
- n_orig_ctx,
18223
+ n_ctx_orig,
18359
18224
  freq_base,
18360
18225
  freq_scale,
18361
18226
  ext_factor,
18362
18227
  attn_factor,
18363
18228
  beta_fast,
18364
- beta_slow,
18365
- xpos_base,
18366
- xpos_down),
18229
+ beta_slow),
18367
18230
  zero_table);
18368
18231
  }
18369
18232
  } break;
@@ -18373,9 +18236,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18373
18236
  //const int n_past = ((int32_t *) tensor->op_params)[0];
18374
18237
  const int n_dims = ((int32_t *) tensor->op_params)[1];
18375
18238
  const int mode = ((int32_t *) tensor->op_params)[2];
18376
- const int n_ctx = ((int32_t *) tensor->op_params)[3];
18377
- const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
18378
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
18239
+ //const int n_ctx = ((int32_t *) tensor->op_params)[3];
18240
+ const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
18241
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
18379
18242
 
18380
18243
  memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
18381
18244
  memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
@@ -18383,8 +18246,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18383
18246
  memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
18384
18247
  memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
18385
18248
  memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
18386
- memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
18387
- memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
18388
18249
 
18389
18250
  src0->grad = ggml_add_or_set(ctx,
18390
18251
  src0->grad,
@@ -18394,16 +18255,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18394
18255
  src2,
18395
18256
  n_dims,
18396
18257
  mode,
18397
- n_ctx,
18398
- n_orig_ctx,
18258
+ n_ctx_orig,
18399
18259
  freq_base,
18400
18260
  freq_scale,
18401
18261
  ext_factor,
18402
18262
  attn_factor,
18403
18263
  beta_fast,
18404
18264
  beta_slow,
18405
- xpos_base,
18406
- xpos_down,
18407
18265
  false),
18408
18266
  zero_table);
18409
18267
  }
@@ -19510,11 +19368,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
19510
19368
  {
19511
19369
  const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
19512
19370
 
19513
- #if defined(GGML_USE_CLBLAST)
19514
- if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
19515
- cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
19516
- } else
19517
- #endif
19518
19371
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
19519
19372
  if (ggml_compute_forward_mul_mat_use_blas(node)) {
19520
19373
  if (node->src[0]->type != GGML_TYPE_F32) {
@@ -19644,6 +19497,59 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
19644
19497
  return cplan;
19645
19498
  }
19646
19499
 
19500
+ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) {
19501
+ enum ggml_status compute_status = GGML_STATUS_SUCCESS;
19502
+
19503
+ #ifdef GGML_USE_OPENMP
19504
+ if (n_threads > 1) {
19505
+ #pragma omp parallel num_threads(n_threads)
19506
+ {
19507
+ #pragma omp single
19508
+ {
19509
+ // update the number of threads from the actual number of threads that we got from OpenMP
19510
+ n_threads = omp_get_num_threads();
19511
+ workers[0].shared->n_threads = n_threads;
19512
+ workers[0].shared->n_active = n_threads;
19513
+ }
19514
+ ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
19515
+ }
19516
+ } else {
19517
+ ggml_graph_compute_thread(&workers[0]);
19518
+ }
19519
+ #else
19520
+ // create thread pool
19521
+ if (n_threads > 1) {
19522
+ for (int j = 1; j < n_threads; ++j) {
19523
+ const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
19524
+ GGML_ASSERT(rc == 0);
19525
+ UNUSED(rc);
19526
+ }
19527
+ }
19528
+
19529
+ // this is a work thread too
19530
+ ggml_graph_compute_thread(&workers[0]);
19531
+
19532
+ // join or kill thread pool
19533
+ if (n_threads > 1) {
19534
+ for (int j = 1; j < n_threads; j++) {
19535
+ const int rc = ggml_thread_join(workers[j].thrd, NULL);
19536
+ GGML_ASSERT(rc == 0);
19537
+ UNUSED(rc);
19538
+ }
19539
+ }
19540
+ #endif
19541
+ // don't leave affinity set on the main thread
19542
+ clear_numa_thread_affinity();
19543
+
19544
+ for (int j = 0; j < n_threads; j++) {
19545
+ if (workers[j].ec != GGML_STATUS_SUCCESS) {
19546
+ compute_status = workers[j].ec;
19547
+ break;
19548
+ }
19549
+ }
19550
+ return compute_status;
19551
+ }
19552
+
19647
19553
  enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
19648
19554
  {
19649
19555
  GGML_ASSERT(cplan);
@@ -19654,7 +19560,11 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
19654
19560
  }
19655
19561
  }
19656
19562
 
19657
- const int n_threads = cplan->n_threads;
19563
+ int n_threads = cplan->n_threads;
19564
+
19565
+ #if defined(GGML_USE_OPENMP)
19566
+ n_threads = MIN(n_threads, omp_get_max_threads());
19567
+ #endif
19658
19568
 
19659
19569
  struct ggml_compute_state_shared state_shared = {
19660
19570
  /*.cgraph =*/ cgraph,
@@ -19670,47 +19580,20 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
19670
19580
  /*.current_chunk; =*/ 0,
19671
19581
  };
19672
19582
  struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
19673
-
19674
- // create thread pool
19675
- if (n_threads > 1) {
19676
- for (int j = 1; j < n_threads; ++j) {
19677
- workers[j] = (struct ggml_compute_state) {
19678
- .thrd = 0,
19679
- .ith = j,
19680
- .shared = &state_shared,
19681
- .ec = GGML_STATUS_SUCCESS,
19682
- };
19683
-
19684
- const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
19685
- GGML_ASSERT(rc == 0);
19686
- UNUSED(rc);
19687
- }
19688
- }
19689
-
19690
- workers[0].ith = 0;
19691
- workers[0].shared = &state_shared;
19692
- workers[0].ec = GGML_STATUS_SUCCESS;
19693
-
19694
19583
  const int64_t perf_start_cycles = ggml_perf_cycles();
19695
19584
  const int64_t perf_start_time_us = ggml_perf_time_us();
19696
19585
 
19697
- // this is a work thread too
19698
- ggml_graph_compute_thread(&workers[0]);
19699
- enum ggml_status compute_status = workers[0].ec;
19700
-
19701
- // don't leave affinity set on the main thread
19702
- clear_numa_thread_affinity();
19703
-
19704
- // join or kill thread pool
19705
- if (n_threads > 1) {
19706
- for (int j = 1; j < n_threads; j++) {
19707
- const int rc = ggml_thread_join(workers[j].thrd, NULL);
19708
- GGML_ASSERT(rc == 0);
19709
- if (workers[j].ec != GGML_STATUS_SUCCESS)
19710
- compute_status = workers[j].ec;
19711
- }
19586
+ for (int j = 0; j < n_threads; ++j) {
19587
+ workers[j] = (struct ggml_compute_state) {
19588
+ .thrd = 0,
19589
+ .ith = j,
19590
+ .shared = &state_shared,
19591
+ .ec = GGML_STATUS_SUCCESS,
19592
+ };
19712
19593
  }
19713
19594
 
19595
+ enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads);
19596
+
19714
19597
  // performance stats (graph)
19715
19598
  {
19716
19599
  int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles;
@@ -22742,6 +22625,16 @@ int ggml_cpu_has_neon(void) {
22742
22625
  #endif
22743
22626
  }
22744
22627
 
22628
// Reports whether this build targets Arm SVE (__ARM_FEATURE_SVE): 1 if so, 0 otherwise.
//
// In SVE builds this also checks the runtime vector length. Per the original TODO,
// only 256-bit SVE is supported, so the call aborts through GGML_ASSERT unless
// svcntb() (the SVE vector length in bytes) equals QK8_0.
// NOTE(review): aborting inside a capability query is surprising for callers that
// only want to print system info. Consider moving the check to the SVE kernels.
int ggml_cpu_has_sve(void) {
    int has_sve = 0;
#if defined(__ARM_FEATURE_SVE)
    // TODO: Currently, SVE 256 bit is only supported.
    GGML_ASSERT(svcntb() == QK8_0);
    has_sve = 1;
#endif
    return has_sve;
}
22637
+
22745
22638
  int ggml_cpu_has_arm_fma(void) {
22746
22639
  #if defined(__ARM_FEATURE_FMA)
22747
22640
  return 1;
@@ -22783,7 +22676,7 @@ int ggml_cpu_has_wasm_simd(void) {
22783
22676
  }
22784
22677
 
22785
22678
  int ggml_cpu_has_blas(void) {
22786
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_SYCL)
22679
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL)
22787
22680
  return 1;
22788
22681
  #else
22789
22682
  return 0;
@@ -22798,14 +22691,6 @@ int ggml_cpu_has_cuda(void) {
22798
22691
  #endif
22799
22692
  }
22800
22693
 
22801
- int ggml_cpu_has_clblast(void) {
22802
- #if defined(GGML_USE_CLBLAST)
22803
- return 1;
22804
- #else
22805
- return 0;
22806
- #endif
22807
- }
22808
-
22809
22694
  int ggml_cpu_has_vulkan(void) {
22810
22695
  #if defined(GGML_USE_VULKAN)
22811
22696
  return 1;
@@ -22830,9 +22715,16 @@ int ggml_cpu_has_sycl(void) {
22830
22715
  #endif
22831
22716
  }
22832
22717
 
22718
// Reports whether the RPC backend was enabled at build time (GGML_USE_RPC):
// returns 1 if it was, 0 otherwise.
int ggml_cpu_has_rpc(void) {
#ifdef GGML_USE_RPC
    const int rpc_enabled = 1;
#else
    const int rpc_enabled = 0;
#endif
    return rpc_enabled;
}
22725
+
22833
22726
  int ggml_cpu_has_gpublas(void) {
22834
- return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
22835
- ggml_cpu_has_sycl();
22727
+ return ggml_cpu_has_cuda() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() || ggml_cpu_has_sycl();
22836
22728
  }
22837
22729
 
22838
22730
  int ggml_cpu_has_sse3(void) {