cui-llama.rn 1.4.4 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. package/android/src/main/CMakeLists.txt +9 -2
  2. package/android/src/main/jni.cpp +54 -34
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/binary-ops.cpp +158 -0
  12. package/cpp/binary-ops.h +16 -0
  13. package/cpp/chat.cpp +1769 -1085
  14. package/cpp/chat.h +143 -0
  15. package/cpp/common.cpp +1562 -1996
  16. package/cpp/common.h +677 -744
  17. package/cpp/cpu-common.h +72 -0
  18. package/cpp/ggml-alloc.c +1039 -1030
  19. package/cpp/ggml-alloc.h +1 -1
  20. package/cpp/ggml-backend-impl.h +255 -255
  21. package/cpp/ggml-backend-reg.cpp +586 -582
  22. package/cpp/ggml-backend.cpp +2004 -2002
  23. package/cpp/ggml-backend.h +354 -354
  24. package/cpp/ggml-common.h +1857 -1851
  25. package/cpp/ggml-cpp.h +39 -39
  26. package/cpp/ggml-cpu-aarch64.cpp +5725 -4247
  27. package/cpp/ggml-cpu-aarch64.h +8 -8
  28. package/cpp/ggml-cpu-impl.h +512 -380
  29. package/cpp/ggml-cpu-quants.c +13026 -11517
  30. package/cpp/ggml-cpu-traits.cpp +36 -36
  31. package/cpp/ggml-cpu-traits.h +38 -38
  32. package/cpp/ggml-cpu.c +3438 -14485
  33. package/cpp/ggml-cpu.cpp +655 -633
  34. package/cpp/ggml-cpu.h +138 -135
  35. package/cpp/ggml-impl.h +594 -567
  36. package/cpp/ggml-metal-impl.h +312 -3
  37. package/cpp/ggml-metal.h +66 -66
  38. package/cpp/ggml-metal.m +5360 -5002
  39. package/cpp/ggml-opt.cpp +854 -854
  40. package/cpp/ggml-opt.h +216 -216
  41. package/cpp/ggml-quants.c +5238 -5238
  42. package/cpp/ggml-threading.h +14 -14
  43. package/cpp/ggml.c +6618 -6524
  44. package/cpp/ggml.h +2222 -2194
  45. package/cpp/gguf.cpp +1330 -1329
  46. package/cpp/gguf.h +202 -202
  47. package/cpp/json-schema-to-grammar.cpp +1024 -1025
  48. package/cpp/json-schema-to-grammar.h +21 -22
  49. package/cpp/json.hpp +24766 -24766
  50. package/cpp/llama-adapter.cpp +382 -347
  51. package/cpp/llama-adapter.h +76 -74
  52. package/cpp/llama-arch.cpp +1714 -1492
  53. package/cpp/llama-arch.h +428 -402
  54. package/cpp/llama-batch.cpp +368 -368
  55. package/cpp/llama-batch.h +88 -88
  56. package/cpp/llama-chat.cpp +640 -587
  57. package/cpp/llama-chat.h +56 -53
  58. package/cpp/llama-context.cpp +2831 -1775
  59. package/cpp/llama-context.h +265 -128
  60. package/cpp/llama-cparams.cpp +1 -1
  61. package/cpp/llama-cparams.h +38 -37
  62. package/cpp/llama-cpp.h +30 -30
  63. package/cpp/llama-grammar.cpp +1219 -1219
  64. package/cpp/llama-grammar.h +173 -164
  65. package/cpp/llama-graph.cpp +1695 -0
  66. package/cpp/llama-graph.h +592 -0
  67. package/cpp/llama-hparams.cpp +79 -71
  68. package/cpp/llama-hparams.h +156 -139
  69. package/cpp/llama-impl.cpp +167 -167
  70. package/cpp/llama-impl.h +61 -61
  71. package/cpp/llama-io.cpp +15 -0
  72. package/cpp/llama-io.h +35 -0
  73. package/cpp/llama-kv-cache.cpp +1380 -718
  74. package/cpp/llama-kv-cache.h +213 -218
  75. package/cpp/llama-memory.cpp +1 -0
  76. package/cpp/llama-memory.h +21 -0
  77. package/cpp/llama-mmap.cpp +600 -590
  78. package/cpp/llama-mmap.h +68 -68
  79. package/cpp/llama-model-loader.cpp +1129 -1124
  80. package/cpp/llama-model-loader.h +169 -167
  81. package/cpp/llama-model.cpp +13080 -4023
  82. package/cpp/llama-model.h +409 -370
  83. package/cpp/llama-sampling.cpp +2563 -2525
  84. package/cpp/llama-sampling.h +32 -32
  85. package/cpp/llama-vocab.cpp +3295 -3252
  86. package/cpp/llama-vocab.h +125 -125
  87. package/cpp/llama.cpp +351 -10137
  88. package/cpp/llama.h +1434 -1340
  89. package/cpp/log.cpp +427 -423
  90. package/cpp/log.h +132 -132
  91. package/cpp/{chat-template.hpp → minja/chat-template.hpp} +537 -529
  92. package/cpp/{minja.hpp → minja/minja.hpp} +2941 -2883
  93. package/cpp/ops.cpp +8723 -0
  94. package/cpp/ops.h +128 -0
  95. package/cpp/rn-llama.cpp +45 -71
  96. package/cpp/rn-llama.h +3 -3
  97. package/cpp/sampling.cpp +573 -532
  98. package/cpp/sgemm.cpp +3043 -2598
  99. package/cpp/sgemm.h +14 -14
  100. package/cpp/simd-mappings.h +888 -0
  101. package/cpp/speculative.cpp +278 -277
  102. package/cpp/speculative.h +28 -28
  103. package/cpp/unary-ops.cpp +186 -0
  104. package/cpp/unary-ops.h +28 -0
  105. package/cpp/vec.cpp +258 -0
  106. package/cpp/vec.h +802 -0
  107. package/ios/CMakeLists.txt +5 -2
  108. package/ios/RNLlama.mm +2 -2
  109. package/ios/RNLlamaContext.mm +40 -24
  110. package/package.json +1 -1
  111. package/src/NativeRNLlama.ts +6 -4
  112. package/src/index.ts +3 -1
  113. package/android/src/main/build-arm64/CMakeCache.txt +0 -429
  114. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  115. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +0 -101
  116. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  117. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  118. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  119. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  120. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  121. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  122. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  123. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -431
  124. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +0 -16
  125. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +0 -165
  126. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +0 -297
  127. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +0 -1
  128. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +0 -1
  129. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +0 -1
  130. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +0 -1
  131. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +0 -1
  132. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +0 -1
  133. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +0 -1
  134. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +0 -8
  135. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +0 -1
  136. package/android/src/main/build-arm64/CMakeFiles/progress.marks +0 -1
  137. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  138. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +0 -58
  139. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  140. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +0 -756
  141. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  142. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +0 -709
  143. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  144. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +0 -714
  145. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  146. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +0 -62
  147. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  148. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +0 -708
  149. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  150. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +0 -113
  151. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  152. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +0 -713
  153. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  154. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +0 -763
  155. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  156. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +0 -61
  157. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  158. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +0 -707
  159. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  160. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +0 -104
  161. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  162. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +0 -714
  163. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  164. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +0 -723
  165. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +0 -62
  166. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +0 -722
  167. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +0 -89
  168. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +0 -2
  169. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +0 -2
  170. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +0 -2
  171. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +0 -17
  172. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +0 -41
  173. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +0 -62
  174. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +0 -722
  175. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +0 -89
  176. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +0 -2
  177. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +0 -2
  178. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +0 -2
  179. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +0 -17
  180. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +0 -41
  181. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +0 -62
  182. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +0 -722
  183. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +0 -89
  184. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +0 -2
  185. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +0 -2
  186. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +0 -2
  187. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +0 -17
  188. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +0 -41
  189. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +0 -62
  190. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +0 -722
  191. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +0 -89
  192. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +0 -2
  193. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +0 -2
  194. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +0 -2
  195. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +0 -17
  196. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +0 -41
  197. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +0 -62
  198. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +0 -722
  199. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +0 -89
  200. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +0 -2
  201. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +0 -2
  202. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +0 -2
  203. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +0 -17
  204. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +0 -41
  205. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +0 -62
  206. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +0 -722
  207. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +0 -89
  208. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +0 -2
  209. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +0 -2
  210. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +0 -2
  211. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +0 -17
  212. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +0 -41
  213. package/android/src/main/build-arm64/Makefile +0 -1862
  214. package/android/src/main/build-arm64/cmake_install.cmake +0 -66
  215. package/cpp/chat.hpp +0 -55
  216. package/cpp/rn-llama.hpp +0 -913
package/cpp/unary-ops.cpp ADDED
@@ -0,0 +1,186 @@
1
+ #include "unary-ops.h"
2
+
3
// Scalar kernels for the element-wise unary ops. Each maps one f32 value to
// one f32 value and is handed to unary_op<> as a template argument.

static inline float op_abs(float x) {
    return fabsf(x);
}

// Sign of x: +1 for x > 0, -1 for x < 0, otherwise 0 (both comparisons are
// false for +/-0 and also for NaN).
static inline float op_sgn(float x) {
    if (x > 0.f) {
        return 1.f;
    }
    if (x < 0.f) {
        return -1.f;
    }
    return 0.f;
}

static inline float op_neg(float x) {
    return -x;
}

// Heaviside step: 1 only for strictly positive x.
static inline float op_step(float x) {
    if (x > 0.f) {
        return 1.f;
    }
    return 0.f;
}

static inline float op_tanh(float x) {
    return tanhf(x);
}

// ELU with alpha = 1: identity above zero, exp(x) - 1 otherwise.
// expm1f avoids cancellation for x close to 0.
static inline float op_elu(float x) {
    if (x > 0.f) {
        return x;
    }
    return expm1f(x);
}

static inline float op_relu(float x) {
    if (x > 0.f) {
        return x;
    }
    return 0.f;
}

static inline float op_sigmoid(float x) {
    return 1.f / (1.f + expf(-x));
}

// Piecewise-linear sigmoid: clamp((x + 3) / 6, 0, 1).
static inline float op_hardsigmoid(float x) {
    return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
}

static inline float op_exp(float x) {
    return expf(x);
}

// x scaled by its hard-sigmoid gate.
static inline float op_hardswish(float x) {
    return x * op_hardsigmoid(x);
}

static inline float op_sqr(float x) {
    return x * x;
}

static inline float op_sqrt(float x) {
    return sqrtf(x);
}

static inline float op_sin(float x) {
    return sinf(x);
}

static inline float op_cos(float x) {
    return cosf(x);
}

static inline float op_log(float x) {
    return logf(x);
}
66
+
67
+ template <float (*op)(float), typename src0_t, typename dst_t>
68
+ static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
69
+ constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
70
+ constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
71
+
72
+ for (int i = 0; i < n; i++) {
73
+ y[i] = f32_to_dst(op(src0_to_f32(x[i])));
74
+ }
75
+ }
76
+
77
+ template <float (*op)(float), typename src0_t, typename dst_t>
78
+ static void apply_unary_op(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
79
+ const lm_ggml_tensor * src0 = dst->src[0];
80
+
81
+ LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0) && lm_ggml_is_contiguous_1(dst) && lm_ggml_are_same_shape(src0, dst));
82
+
83
+ LM_GGML_TENSOR_UNARY_OP_LOCALS
84
+
85
+ LM_GGML_ASSERT( nb0 == sizeof(dst_t));
86
+ LM_GGML_ASSERT(nb00 == sizeof(src0_t));
87
+
88
+ const auto [ir0, ir1] = get_thread_range(params, src0);
89
+
90
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
91
+ const int64_t i03 = ir/(ne02*ne01);
92
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
93
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
94
+
95
+ dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
96
+ const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
97
+
98
+ vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
99
+ }
100
+ }
101
+
102
+ // TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
103
+ template <float (*op)(float)>
104
+ static void unary_op(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
105
+ const lm_ggml_tensor * src0 = dst->src[0];
106
+
107
+ /* */ if (src0->type == LM_GGML_TYPE_F32 && dst->type == LM_GGML_TYPE_F32) { // all f32
108
+ apply_unary_op<op, float, float>(params, dst);
109
+ } else if (src0->type == LM_GGML_TYPE_F16 && dst->type == LM_GGML_TYPE_F16) { // all f16
110
+ apply_unary_op<op, lm_ggml_fp16_t, lm_ggml_fp16_t>(params, dst);
111
+ } else if (src0->type == LM_GGML_TYPE_BF16 && dst->type == LM_GGML_TYPE_BF16) { // all bf16
112
+ apply_unary_op<op, lm_ggml_bf16_t, lm_ggml_bf16_t>(params, dst);
113
+ } else if (src0->type == LM_GGML_TYPE_BF16 && dst->type == LM_GGML_TYPE_F32) {
114
+ apply_unary_op<op, lm_ggml_bf16_t, float>(params, dst);
115
+ } else if (src0->type == LM_GGML_TYPE_F16 && dst->type == LM_GGML_TYPE_F32) {
116
+ apply_unary_op<op, lm_ggml_fp16_t, float>(params, dst);
117
+ } else {
118
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
119
+ lm_ggml_type_name(dst->type), lm_ggml_type_name(src0->type));
120
+ LM_GGML_ABORT("fatal error");
121
+ }
122
+ }
123
+
124
+ void lm_ggml_compute_forward_abs(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
125
+ unary_op<op_abs>(params, dst);
126
+ }
127
+
128
+ void lm_ggml_compute_forward_sgn(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
129
+ unary_op<op_sgn>(params, dst);
130
+ }
131
+
132
+ void lm_ggml_compute_forward_neg(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
133
+ unary_op<op_neg>(params, dst);
134
+ }
135
+
136
+ void lm_ggml_compute_forward_step(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
137
+ unary_op<op_step>(params, dst);
138
+ }
139
+
140
+ void lm_ggml_compute_forward_tanh(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
141
+ unary_op<op_tanh>(params, dst);
142
+ }
143
+
144
+ void lm_ggml_compute_forward_elu(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
145
+ unary_op<op_elu>(params, dst);
146
+ }
147
+
148
+ void lm_ggml_compute_forward_relu(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
149
+ unary_op<op_relu>(params, dst);
150
+ }
151
+
152
+ void lm_ggml_compute_forward_sigmoid(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
153
+ unary_op<op_sigmoid>(params, dst);
154
+ }
155
+
156
+ void lm_ggml_compute_forward_hardsigmoid(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
157
+ unary_op<op_hardsigmoid>(params, dst);
158
+ }
159
+
160
+ void lm_ggml_compute_forward_exp(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
161
+ unary_op<op_exp>(params, dst);
162
+ }
163
+
164
+ void lm_ggml_compute_forward_hardswish(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
165
+ unary_op<op_hardswish>(params, dst);
166
+ }
167
+
168
+ void lm_ggml_compute_forward_sqr(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
169
+ unary_op<op_sqr>(params, dst);
170
+ }
171
+
172
+ void lm_ggml_compute_forward_sqrt(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
173
+ unary_op<op_sqrt>(params, dst);
174
+ }
175
+
176
+ void lm_ggml_compute_forward_sin(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
177
+ unary_op<op_sin>(params, dst);
178
+ }
179
+
180
+ void lm_ggml_compute_forward_cos(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
181
+ unary_op<op_cos>(params, dst);
182
+ }
183
+
184
+ void lm_ggml_compute_forward_log(const lm_ggml_compute_params * params, lm_ggml_tensor * dst) {
185
+ unary_op<op_log>(params, dst);
186
+ }
package/cpp/unary-ops.h ADDED
@@ -0,0 +1,28 @@
1
+ #pragma once
2
+
3
+ #include "cpu-common.h"
4
+
5
#ifdef __cplusplus
extern "C" {
#endif

// CPU forward kernels for element-wise unary ops. Each one reads its input from
// dst->src[0] and writes the result into dst; params is the compute context.

// sign / magnitude
void lm_ggml_compute_forward_abs(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_sgn(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_neg(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_step(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);

// activations
void lm_ggml_compute_forward_tanh(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_elu(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_relu(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_sigmoid(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_hardsigmoid(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_hardswish(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);

// elementary math
void lm_ggml_compute_forward_exp(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_sqr(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_sqrt(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_sin(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_cos(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);
void lm_ggml_compute_forward_log(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst);

#ifdef __cplusplus
}
#endif
package/cpp/vec.cpp ADDED
@@ -0,0 +1,258 @@
1
+ #include "vec.h"
2
+
3
+ #include <cassert>
4
+
5
+ #if defined(_MSC_VER)
6
+ // disable "possible loss of data" to avoid hundreds of casts
7
+ // we should just be careful :)
8
+ #pragma warning(disable: 4244 4267)
9
+ #endif
10
+
11
+ // precomputed gelu table for f16 (128 KB)
12
+ lm_ggml_fp16_t lm_ggml_table_gelu_f16[1 << 16];
13
+
14
+ // precomputed quick gelu table for f16 (128 KB)
15
+ lm_ggml_fp16_t lm_ggml_table_gelu_quick_f16[1 << 16];
16
+
17
+ void lm_ggml_vec_dot_f32(int n, float * LM_GGML_RESTRICT s, size_t bs, const float * LM_GGML_RESTRICT x, size_t bx, const float * LM_GGML_RESTRICT y, size_t by, int nrc) {
18
+ assert(nrc == 1);
19
+ LM_GGML_UNUSED(nrc);
20
+ LM_GGML_UNUSED(bx);
21
+ LM_GGML_UNUSED(by);
22
+ LM_GGML_UNUSED(bs);
23
+
24
+ #if defined(LM_GGML_SIMD)
25
+ float sumf = 0.0f;
26
+ const int np = (n & ~(LM_GGML_F32_STEP - 1));
27
+
28
+ LM_GGML_F32_VEC sum[LM_GGML_F32_ARR] = { LM_GGML_F32_VEC_ZERO };
29
+
30
+ LM_GGML_F32_VEC ax[LM_GGML_F32_ARR];
31
+ LM_GGML_F32_VEC ay[LM_GGML_F32_ARR];
32
+
33
+ for (int i = 0; i < np; i += LM_GGML_F32_STEP) {
34
+ for (int j = 0; j < LM_GGML_F32_ARR; j++) {
35
+ ax[j] = LM_GGML_F32_VEC_LOAD(x + i + j*LM_GGML_F32_EPR);
36
+ ay[j] = LM_GGML_F32_VEC_LOAD(y + i + j*LM_GGML_F32_EPR);
37
+
38
+ sum[j] = LM_GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
39
+ }
40
+ }
41
+
42
+ // reduce sum0..sum3 to sum0
43
+ LM_GGML_F32_VEC_REDUCE(sumf, sum);
44
+
45
+ // leftovers
46
+ for (int i = np; i < n; ++i) {
47
+ sumf += x[i]*y[i];
48
+ }
49
+ #else
50
+ // scalar
51
+ lm_ggml_float sumf = 0.0;
52
+ for (int i = 0; i < n; ++i) {
53
+ sumf += (lm_ggml_float)(x[i]*y[i]);
54
+ }
55
+ #endif
56
+
57
+ *s = sumf;
58
+ }
59
+
60
+ void lm_ggml_vec_dot_bf16(int n, float * LM_GGML_RESTRICT s, size_t bs, lm_ggml_bf16_t * LM_GGML_RESTRICT x, size_t bx, lm_ggml_bf16_t * LM_GGML_RESTRICT y, size_t by, int nrc) {
61
+ assert(nrc == 1);
62
+ LM_GGML_UNUSED(nrc);
63
+ LM_GGML_UNUSED(bx);
64
+ LM_GGML_UNUSED(by);
65
+ LM_GGML_UNUSED(bs);
66
+ int i = 0;
67
+ lm_ggml_float sumf = 0;
68
+
69
+ #if defined(__AVX512BF16__)
70
+ __m512 c1 = _mm512_setzero_ps();
71
+ __m512 c2 = _mm512_setzero_ps();
72
+ for (; i + 64 <= n; i += 64) {
73
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
74
+ m512bh(_mm512_loadu_si512((y + i))));
75
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
76
+ m512bh(_mm512_loadu_si512((y + i + 32))));
77
+ }
78
+ sumf += (lm_ggml_float)_mm512_reduce_add_ps(c1);
79
+ sumf += (lm_ggml_float)_mm512_reduce_add_ps(c2);
80
+
81
+ #elif defined(__AVX512F__)
82
+ #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
83
+ __m512 c1 = _mm512_setzero_ps();
84
+ __m512 c2 = _mm512_setzero_ps();
85
+ for (; i + 32 <= n; i += 32) {
86
+ c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
87
+ c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
88
+ }
89
+ sumf += (lm_ggml_float)_mm512_reduce_add_ps(c1);
90
+ sumf += (lm_ggml_float)_mm512_reduce_add_ps(c2);
91
+
92
+ #undef LOAD
93
+ #elif defined(__AVX2__) || defined(__AVX__)
94
+ #if defined(__AVX2__)
95
+ #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
96
+ #else
97
+ #define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
98
+ #endif
99
+ __m256 c1 = _mm256_setzero_ps();
100
+ __m256 c2 = _mm256_setzero_ps();
101
+ __m256 c3 = _mm256_setzero_ps();
102
+ __m256 c4 = _mm256_setzero_ps();
103
+ for (; i + 32 <= n; i += 32) {
104
+ c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
105
+ c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
106
+ c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
107
+ c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
108
+ }
109
+ __m128 g;
110
+ c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
111
+ _mm256_add_ps(c2, c4));
112
+ g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
113
+ _mm256_castps256_ps128(c1));
114
+ g = _mm_add_ps(g, _mm_movehl_ps(g, g));
115
+ g = _mm_add_ss(g, _mm_movehdup_ps(g));
116
+ sumf += (lm_ggml_float)_mm_cvtss_f32(g);
117
+
118
+ #undef LOAD
119
+ #endif
120
+
121
+ for (; i < n; ++i) {
122
+ sumf += (lm_ggml_float)(LM_GGML_BF16_TO_FP32(x[i]) *
123
+ LM_GGML_BF16_TO_FP32(y[i]));
124
+ }
125
+ *s = sumf;
126
+ }
127
+
128
+ void lm_ggml_vec_dot_f16(int n, float * LM_GGML_RESTRICT s, size_t bs, lm_ggml_fp16_t * LM_GGML_RESTRICT x, size_t bx, lm_ggml_fp16_t * LM_GGML_RESTRICT y, size_t by, int nrc) {
129
+ assert(nrc == 1);
130
+ LM_GGML_UNUSED(nrc);
131
+ LM_GGML_UNUSED(bx);
132
+ LM_GGML_UNUSED(by);
133
+ LM_GGML_UNUSED(bs);
134
+
135
+ lm_ggml_float sumf = 0.0;
136
+
137
+ #if defined(LM_GGML_SIMD)
138
+ const int np = (n & ~(LM_GGML_F16_STEP - 1));
139
+
140
+ LM_GGML_F16_VEC sum[LM_GGML_F16_ARR] = { LM_GGML_F16_VEC_ZERO };
141
+
142
+ LM_GGML_F16_VEC ax[LM_GGML_F16_ARR];
143
+ LM_GGML_F16_VEC ay[LM_GGML_F16_ARR];
144
+
145
+ for (int i = 0; i < np; i += LM_GGML_F16_STEP) {
146
+ for (int j = 0; j < LM_GGML_F16_ARR; j++) {
147
+ ax[j] = LM_GGML_F16_VEC_LOAD(x + i + j*LM_GGML_F16_EPR, j);
148
+ ay[j] = LM_GGML_F16_VEC_LOAD(y + i + j*LM_GGML_F16_EPR, j);
149
+
150
+ sum[j] = LM_GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
151
+ }
152
+ }
153
+
154
+ // reduce sum0..sum3 to sum0
155
+ LM_GGML_F16_VEC_REDUCE(sumf, sum);
156
+
157
+ // leftovers
158
+ for (int i = np; i < n; ++i) {
159
+ sumf += (lm_ggml_float)(LM_GGML_FP16_TO_FP32(x[i])*LM_GGML_FP16_TO_FP32(y[i]));
160
+ }
161
+ #else
162
+ for (int i = 0; i < n; ++i) {
163
+ sumf += (lm_ggml_float)(LM_GGML_FP16_TO_FP32(x[i])*LM_GGML_FP16_TO_FP32(y[i]));
164
+ }
165
+ #endif
166
+
167
+ *s = sumf;
168
+ }
169
+
170
+ void lm_ggml_vec_silu_f32(const int n, float * y, const float * x) {
171
+ int i = 0;
172
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
173
+ for (; i + 15 < n; i += 16) {
174
+ _mm512_storeu_ps(y + i, lm_ggml_v_silu(_mm512_loadu_ps(x + i)));
175
+ }
176
+ #elif defined(__AVX2__) && defined(__FMA__)
177
+ for (; i + 7 < n; i += 8) {
178
+ _mm256_storeu_ps(y + i, lm_ggml_v_silu(_mm256_loadu_ps(x + i)));
179
+ }
180
+ #elif defined(__SSE2__)
181
+ for (; i + 3 < n; i += 4) {
182
+ _mm_storeu_ps(y + i, lm_ggml_v_silu(_mm_loadu_ps(x + i)));
183
+ }
184
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
185
+ for (; i + 3 < n; i += 4) {
186
+ vst1q_f32(y + i, lm_ggml_v_silu(vld1q_f32(x + i)));
187
+ }
188
+ #endif
189
+ for (; i < n; ++i) {
190
+ y[i] = lm_ggml_silu_f32(x[i]);
191
+ }
192
+ }
193
+
194
+ lm_ggml_float lm_ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
195
+ int i = 0;
196
+ lm_ggml_float sum = 0;
197
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
198
+ for (; i + 15 < n; i += 16) {
199
+ __m512 val = lm_ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
200
+ _mm512_set1_ps(max)));
201
+ _mm512_storeu_ps(y + i, val);
202
+ sum += (lm_ggml_float)_mm512_reduce_add_ps(val);
203
+ }
204
+ #elif defined(__AVX2__) && defined(__FMA__)
205
+ for (; i + 7 < n; i += 8) {
206
+ __m256 val = lm_ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
207
+ _mm256_set1_ps(max)));
208
+ _mm256_storeu_ps(y + i, val);
209
+ __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
210
+ _mm256_castps256_ps128(val));
211
+ val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
212
+ val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
213
+ sum += (lm_ggml_float)_mm_cvtss_f32(val2);
214
+ }
215
+ #elif defined(__SSE2__)
216
+ for (; i + 3 < n; i += 4) {
217
+ __m128 val = lm_ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
218
+ _mm_set1_ps(max)));
219
+ _mm_storeu_ps(y + i, val);
220
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
221
+ val = _mm_add_ps(val, _mm_movehl_ps(val, val));
222
+ val = _mm_add_ss(val, _mm_movehdup_ps(val));
223
+ #else
224
+ __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
225
+ val = _mm_add_ps(val, tmp);
226
+ tmp = _mm_movehl_ps(tmp, val);
227
+ val = _mm_add_ss(val, tmp);
228
+ #endif
229
+ sum += (lm_ggml_float)_mm_cvtss_f32(val);
230
+ }
231
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
232
+ for (; i + 3 < n; i += 4) {
233
+ float32x4_t val = lm_ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
234
+ vdupq_n_f32(max)));
235
+ vst1q_f32(y + i, val);
236
+ sum += (lm_ggml_float)vaddvq_f32(val);
237
+ }
238
+ #endif
239
+ for (; i < n; ++i) {
240
+ float val = expf(x[i] - max);
241
+ sum += (lm_ggml_float)val;
242
+ y[i] = val;
243
+ }
244
+ return sum;
245
+ }
246
+
247
+ lm_ggml_float lm_ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
248
+ // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i)
249
+
250
+ int i = 0;
251
+ lm_ggml_float sum = 0;
252
+ for (; i < n; ++i) {
253
+ float val = x[i] - max;
254
+ y[i] = val;
255
+ sum += (lm_ggml_float)expf(val);
256
+ }
257
+ return sum = (lm_ggml_float)logf(sum);
258
+ }